diff --git a/backend/danswer/chat/answer.py b/backend/danswer/chat/answer.py
index d2db03186c9..578529364cd 100644
--- a/backend/danswer/chat/answer.py
+++ b/backend/danswer/chat/answer.py
@@ -206,7 +206,9 @@ class Answer:
         # + figure out what the next LLM call should be
         tool_call_handler = ToolResponseHandler(current_llm_call.tools)

-        search_result = SearchTool.get_search_result(current_llm_call) or []
+        search_result, displayed_search_results_map = SearchTool.get_search_result(
+            current_llm_call
+        ) or ([], {})

         # Quotes are no longer supported
         # answer_handler: AnswerResponseHandler
@@ -224,6 +226,7 @@ class Answer:
             answer_handler = CitationResponseHandler(
                 context_docs=search_result,
                 doc_id_to_rank_map=map_document_id_order(search_result),
+                display_doc_order_dict=displayed_search_results_map,
             )

         response_handler_manager = LLMResponseHandlerManager(
diff --git a/backend/danswer/chat/stream_processing/answer_response_handler.py b/backend/danswer/chat/stream_processing/answer_response_handler.py
index 8a8bda40d9d..a10f46be5f5 100644
--- a/backend/danswer/chat/stream_processing/answer_response_handler.py
+++ b/backend/danswer/chat/stream_processing/answer_response_handler.py
@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):

 class CitationResponseHandler(AnswerResponseHandler):
     def __init__(
-        self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
+        self,
+        context_docs: list[LlmDoc],
+        doc_id_to_rank_map: DocumentIdOrderMapping,
+        display_doc_order_dict: dict[str, int],
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
+        self.display_doc_order_dict = display_doc_order_dict
         self.citation_processor = CitationProcessor(
             context_docs=self.context_docs,
             doc_id_to_rank_map=self.doc_id_to_rank_map,
+            display_doc_order_dict=self.display_doc_order_dict,
         )
         self.processed_text = ""
         self.citations: list[CitationInfo] = []
diff --git a/backend/danswer/chat/stream_processing/citation_processing.py b/backend/danswer/chat/stream_processing/citation_processing.py
index 5a50855e98c..8966303faff 100644
--- a/backend/danswer/chat/stream_processing/citation_processing.py
+++ b/backend/danswer/chat/stream_processing/citation_processing.py
@@ -22,12 +22,16 @@ class CitationProcessor:
         self,
         context_docs: list[LlmDoc],
         doc_id_to_rank_map: DocumentIdOrderMapping,
+        display_doc_order_dict: dict[str, int],
         stop_stream: str | None = STOP_STREAM_PAT,
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
         self.stop_stream = stop_stream
         self.order_mapping = doc_id_to_rank_map.order_mapping
+        self.display_doc_order_dict = (
+            display_doc_order_dict  # original order in which docs were displayed to the user
+        )
         self.llm_out = ""
         self.max_citation_num = len(context_docs)
         self.citation_order: list[int] = []
@@ -98,6 +102,18 @@ class CitationProcessor:
                     self.citation_order.index(real_citation_num) + 1
                )

+                # Get the citation number that was displayed to the user; it should
+                # always be in display_doc_order_dict, but fall back just in case.
+                if context_llm_doc.document_id in self.display_doc_order_dict:
+                    displayed_citation_num = self.display_doc_order_dict[
+                        context_llm_doc.document_id
+                    ]
+                else:
+                    displayed_citation_num = real_citation_num
+                    logger.warning(
+                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
+                    )
+
                # Skip consecutive citations of the same work
                if target_citation_num in self.current_citations:
                    start, end = citation.span()
@@ -118,6 +134,7 @@ class CitationProcessor:
                    doc_id = int(match.group(1))
                    context_llm_doc = self.context_docs[doc_id - 1]
                    yield CitationInfo(
+                        # keep the original number here for now (order of LLM citations)
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )
@@ -139,6 +156,7 @@ class CitationProcessor:
                if target_citation_num not in self.cited_inds:
                    self.cited_inds.add(target_citation_num)
                    yield CitationInfo(
+                        # keep the original number here for now (order of LLM citations)
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )
@@ -148,7 +166,8 @@ class CitationProcessor:
                    prev_length = len(self.curr_segment)
                    self.curr_segment = (
                        self.curr_segment[: start + length_to_add]
-                        + f"[[{target_citation_num}]]({link})"
+                        + f"[[{displayed_citation_num}]]({link})"  # use the number that was displayed to the user
+                        # + f"[[{target_citation_num}]]({link})"
                        + self.curr_segment[end + length_to_add :]
                    )
                    length_to_add += len(self.curr_segment) - prev_length
@@ -156,7 +175,8 @@ class CitationProcessor:
                    prev_length = len(self.curr_segment)
                    self.curr_segment = (
                        self.curr_segment[: start + length_to_add]
-                        + f"[[{target_citation_num}]]()"
+                        + f"[[{displayed_citation_num}]]()"  # use the number that was displayed to the user
+                        # + f"[[{target_citation_num}]]()"
                        + self.curr_segment[end + length_to_add :]
                    )
                    length_to_add += len(self.curr_segment) - prev_length
diff --git a/backend/danswer/tools/tool_implementations/search/search_tool.py b/backend/danswer/tools/tool_implementations/search/search_tool.py
index 5bf08e564e2..a0c686bd6cf 100644
--- a/backend/danswer/tools/tool_implementations/search/search_tool.py
+++ b/backend/danswer/tools/tool_implementations/search/search_tool.py
@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
 from danswer.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from danswer.tools.tool_implementations.search_like_tool_utils import (
+    ORIGINAL_CONTEXT_DOCUMENTS_ID,
+)
 from danswer.utils.logger import setup_logger
 from danswer.utils.special_types import JSON_ro

@@ -391,15 +394,35 @@ class SearchTool(Tool):
     """Other utility functions"""

     @classmethod
-    def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
+    def get_search_result(
+        cls, llm_call: LLMCall
+    ) -> tuple[list[LlmDoc], dict[str, int]] | None:
+        """
+        Returns the final search results and a map of document IDs to their original search rank (the order displayed to the user).
+        """
         if not llm_call.tool_call_info:
             return None

+        final_search_results = []
+        doc_id_to_original_search_rank_map = {}
+
         for yield_item in llm_call.tool_call_info:
             if (
                 isinstance(yield_item, ToolResponse)
                 and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
             ):
-                return cast(list[LlmDoc], yield_item.response)
+                final_search_results = cast(list[LlmDoc], yield_item.response)
+            elif (
+                isinstance(yield_item, ToolResponse)
+                and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
+            ):
+                search_contexts = yield_item.response.contexts
+                original_doc_search_rank = 1
+                for doc in search_contexts:
+                    if doc.document_id not in doc_id_to_original_search_rank_map:
+                        doc_id_to_original_search_rank_map[
+                            doc.document_id
+                        ] = original_doc_search_rank
+                        original_doc_search_rank += 1

-        return None
+        return final_search_results, doc_id_to_original_search_rank_map
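(A minimal sketch of the rank bookkeeping in get_search_result above, using hypothetical document IDs; duplicate hits keep the rank of their first appearance:

    # Hypothetical pre-reranking hits, in the order they were displayed to the user.
    search_context_doc_ids = ["doc_a", "doc_b", "doc_a", "doc_c"]

    doc_id_to_original_search_rank_map: dict[str, int] = {}
    original_doc_search_rank = 1
    for doc_id in search_context_doc_ids:
        if doc_id not in doc_id_to_original_search_rank_map:
            doc_id_to_original_search_rank_map[doc_id] = original_doc_search_rank
            original_doc_search_rank += 1

    assert doc_id_to_original_search_rank_map == {"doc_a": 1, "doc_b": 2, "doc_c": 3}
)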
diff --git a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py
index 761e2f9eccb..7edb22fc144 100644
--- a/backend/danswer/tools/tool_implementations/search_like_tool_utils.py
+++ b/backend/danswer/tools/tool_implementations/search_like_tool_utils.py
@@ -15,6 +15,7 @@
 from danswer.tools.message import ToolCallSummary
 from danswer.tools.models import ToolResponse

+ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
 FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"

diff --git a/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py
index 563e26780fb..178240c7176 100644
--- a/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py
+++ b/backend/tests/unit/danswer/chat/stream_processing/test_citation_processing.py
@@ -72,8 +72,10 @@ def process_text(
     processor = CitationProcessor(
         context_docs=mock_docs,
         doc_id_to_rank_map=mapping,
+        display_doc_order_dict=mock_doc_id_to_rank_map,
         stop_stream=None,
     )
+
     result: list[DanswerAnswerPiece | CitationInfo] = []
     for token in tokens:
         result.extend(processor.process_token(token))
@@ -86,6 +88,7 @@ def process_text(
             final_answer_text += piece.answer_piece or ""
         elif isinstance(piece, CitationInfo):
             citations.append(piece)
+
     return final_answer_text, citations

diff --git a/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py b/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py
new file mode 100644
index 00000000000..841d76a3247
--- /dev/null
+++ b/backend/tests/unit/danswer/chat/stream_processing/test_citation_substitution.py
@@ -0,0 +1,132 @@
+from datetime import datetime
+
+import pytest
+
+from danswer.chat.models import CitationInfo
+from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import LlmDoc
+from danswer.chat.stream_processing.citation_processing import CitationProcessor
+from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
+from danswer.configs.constants import DocumentSource
+
+
+"""
+This module contains tests for the citation extraction functionality in Danswer,
+specifically the substitution of the citation number shown in the UI. (The LLM
+sees the sources after re-ranking and the relevance check, while the UI shows
+them in their original order, before these steps.) This module is a derivative
+of test_citation_processing.py and focuses solely on that numbering substitution.
+
+Key components:
+- mock_docs: A list of mock LlmDoc objects used for testing.
+- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
+- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after the re-ranking/relevance check.
+- process_text: A helper function that simulates the citation extraction process.
+- test_citation_substitution: A parametrized test function covering various citation scenarios.
+
+To add new test cases:
+1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_substitution.
+2. Each tuple should contain:
+   - A descriptive test name (string)
+   - Input tokens (list of strings)
+   - Expected output text (string)
+   - Expected citations (list of document IDs)
+"""
+
+
+mock_docs = [
+    LlmDoc(
+        document_id=f"doc_{int(id/2)}",
+        content="Document is a doc",
+        blurb=f"Document #{id}",
+        semantic_identifier=f"Doc {id}",
+        source_type=DocumentSource.WEB,
+        metadata={},
+        updated_at=datetime.now(),
+        link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
+        source_links={0: "https://mintlify.com/docs/settings/broken-links"},
+        match_highlights=[],
+    )
+    for id in range(10)
+]
+
+mock_doc_mapping = {
+    "doc_0": 1,
+    "doc_1": 2,
+    "doc_2": 3,
+    "doc_3": 4,
+    "doc_4": 5,
+    "doc_5": 6,
+}
+
+mock_doc_mapping_rerank = {
+    "doc_0": 2,
+    "doc_1": 1,
+    "doc_2": 4,
+    "doc_3": 3,
+    "doc_4": 6,
+    "doc_5": 5,
+}
+
+
+@pytest.fixture
+def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
+    return mock_docs, mock_doc_mapping
+
+
+def process_text(
+    tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
+) -> tuple[str, list[CitationInfo]]:
+    mock_docs, mock_doc_id_to_rank_map = mock_data
+    mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
+    processor = CitationProcessor(
+        context_docs=mock_docs,
+        doc_id_to_rank_map=mapping,
+        display_doc_order_dict=mock_doc_mapping_rerank,
+        stop_stream=None,
+    )
+
+    result: list[DanswerAnswerPiece | CitationInfo] = []
+    for token in tokens:
+        result.extend(processor.process_token(token))
+    result.extend(processor.process_token(None))
+
+    final_answer_text = ""
+    citations = []
+    for piece in result:
+        if isinstance(piece, DanswerAnswerPiece):
+            final_answer_text += piece.answer_piece or ""
+        elif isinstance(piece, CitationInfo):
+            citations.append(piece)
+
+    return final_answer_text, citations
+
+
+@pytest.mark.parametrize(
+    "test_name, input_tokens, expected_text, expected_citations",
+    [
+        (
+            "Single citation",
+            ["Gro", "wth! [", "1", "]", "."],
+            "Growth! [[2]](https://0.com).",
+            ["doc_0"],
+        ),
+    ],
+)
+def test_citation_substitution(
+    mock_data: tuple[list[LlmDoc], dict[str, int]],
+    test_name: str,
+    input_tokens: list[str],
+    expected_text: str,
+    expected_citations: list[str],
+) -> None:
+    final_answer_text, citations = process_text(input_tokens, mock_data)
+
+    assert (
+        final_answer_text.strip() == expected_text.strip()
+    ), f"Test '{test_name}' failed: Final answer text does not match expected output."
+    assert [
+        citation.document_id for citation in citations
+    ] == expected_citations, (
+        f"Test '{test_name}' failed: Citations do not match expected output."
+    )
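(For orientation, a minimal sketch of the substitution exercised by the "Single citation" case above: the LLM's "[1]" resolves to doc_0, but the rendered link uses the number from display_doc_order_dict, which is 2 for doc_0 in mock_doc_mapping_rerank. The standalone variable names below are illustrative only:

    # Illustrative only: mirrors how CitationProcessor swaps in the displayed number.
    display_doc_order_dict = {"doc_0": 2, "doc_1": 1}  # subset of mock_doc_mapping_rerank
    target_citation_num = 1  # the number the LLM emitted: "[1]" resolves to doc_0
    document_id, link = "doc_0", "https://0.com"
    displayed_citation_num = display_doc_order_dict.get(document_id, target_citation_num)
    rendered = f"[[{displayed_citation_num}]]({link})"
    assert rendered == "[[2]](https://0.com)"  # matches expected_text for "Single citation"
)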