mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 05:43:33 +02:00
Mismatch issue of Documents shown and Citation number in text fix (#3421)
* Mismatch issue of Documents shown and Citation number in text fix When document order presented to LLM differs from order shown to user, wrong doc numbers are cited. Fix: - SearchTool.get_search_result returns now final and initial ranking - initial ranking is passed through a few objects and used for replacement in citation processing Notes: - the citation_num in the CitationInfo() object has not been changed. * PR fixes - linting - removed erroneous tab - added a substitution test case - adjusted original citation extraction use case * Included a key test and * Fixed extra spaces * Updated test documentation Updated: - test_citation_substitution (changed description) - test_citation_processing (removed data only relevant for the substitution)
This commit is contained in:
@@ -206,7 +206,9 @@ class Answer:
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], {})
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
@@ -224,6 +226,7 @@ class Answer:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
display_doc_order_dict=displayed_search_results_map,
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
|
@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
|
||||
|
||||
class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.display_doc_order_dict = display_doc_order_dict
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
display_doc_order_dict=self.display_doc_order_dict,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
@@ -22,12 +22,16 @@ class CitationProcessor:
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.display_doc_order_dict = (
|
||||
display_doc_order_dict # original order of docs to displayed to user
|
||||
)
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
@@ -98,6 +102,18 @@ class CitationProcessor:
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
||||
displayed_citation_num = self.display_doc_order_dict[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = real_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
@@ -118,6 +134,7 @@ class CitationProcessor:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
@@ -139,6 +156,7 @@ class CitationProcessor:
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
@@ -148,7 +166,8 @@ class CitationProcessor:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]({link})"
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -156,7 +175,8 @@ class CitationProcessor:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{target_citation_num}]]()"
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
ORIGINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.special_types import JSON_ro
|
||||
|
||||
@@ -391,15 +394,35 @@ class SearchTool(Tool):
|
||||
"""Other utility functions"""
|
||||
|
||||
@classmethod
|
||||
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
|
||||
def get_search_result(
|
||||
cls, llm_call: LLMCall
|
||||
) -> tuple[list[LlmDoc], dict[str, int]] | None:
|
||||
"""
|
||||
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
|
||||
"""
|
||||
if not llm_call.tool_call_info:
|
||||
return None
|
||||
|
||||
final_search_results = []
|
||||
doc_id_to_original_search_rank_map = {}
|
||||
|
||||
for yield_item in llm_call.tool_call_info:
|
||||
if (
|
||||
isinstance(yield_item, ToolResponse)
|
||||
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
):
|
||||
return cast(list[LlmDoc], yield_item.response)
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif (
|
||||
isinstance(yield_item, ToolResponse)
|
||||
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
|
||||
):
|
||||
search_contexts = yield_item.response.contexts
|
||||
original_doc_search_rank = 1
|
||||
for idx, doc in enumerate(search_contexts):
|
||||
if doc.document_id not in doc_id_to_original_search_rank_map:
|
||||
doc_id_to_original_search_rank_map[
|
||||
doc.document_id
|
||||
] = original_doc_search_rank
|
||||
original_doc_search_rank += 1
|
||||
|
||||
return None
|
||||
return final_search_results, doc_id_to_original_search_rank_map
|
||||
|
@@ -15,6 +15,7 @@ from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
|
||||
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
|
||||
|
||||
|
||||
|
@@ -72,8 +72,10 @@ def process_text(
|
||||
processor = CitationProcessor(
|
||||
context_docs=mock_docs,
|
||||
doc_id_to_rank_map=mapping,
|
||||
display_doc_order_dict=mock_doc_id_to_rank_map,
|
||||
stop_stream=None,
|
||||
)
|
||||
|
||||
result: list[DanswerAnswerPiece | CitationInfo] = []
|
||||
for token in tokens:
|
||||
result.extend(processor.process_token(token))
|
||||
@@ -86,6 +88,7 @@ def process_text(
|
||||
final_answer_text += piece.answer_piece or ""
|
||||
elif isinstance(piece, CitationInfo):
|
||||
citations.append(piece)
|
||||
|
||||
return final_answer_text, citations
|
||||
|
||||
|
||||
|
@@ -0,0 +1,132 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
"""
|
||||
This module contains tests for the citation extraction functionality in Danswer,
|
||||
specifically the substitution of the number of document cited in the UI. (The LLM
|
||||
will see the sources post re-ranking and relevance check, the UI before these steps.)
|
||||
This module is a derivative of test_citation_processing.py.
|
||||
|
||||
The tests focusses specifically on the substitution of the number of document cited in the UI.
|
||||
|
||||
Key components:
|
||||
- mock_docs: A list of mock LlmDoc objects used for testing.
|
||||
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
|
||||
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
|
||||
- process_text: A helper function that simulates the citation extraction process.
|
||||
- test_citation_extraction: A parametrized test function covering various citation scenarios.
|
||||
|
||||
To add new test cases:
|
||||
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction.
|
||||
2. Each tuple should contain:
|
||||
- A descriptive test name (string)
|
||||
- Input tokens (list of strings)
|
||||
- Expected output text (string)
|
||||
- Expected citations (list of document IDs)
|
||||
"""
|
||||
|
||||
|
||||
mock_docs = [
|
||||
LlmDoc(
|
||||
document_id=f"doc_{int(id/2)}",
|
||||
content="Document is a doc",
|
||||
blurb=f"Document #{id}",
|
||||
semantic_identifier=f"Doc {id}",
|
||||
source_type=DocumentSource.WEB,
|
||||
metadata={},
|
||||
updated_at=datetime.now(),
|
||||
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
|
||||
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
|
||||
match_highlights=[],
|
||||
)
|
||||
for id in range(10)
|
||||
]
|
||||
|
||||
mock_doc_mapping = {
|
||||
"doc_0": 1,
|
||||
"doc_1": 2,
|
||||
"doc_2": 3,
|
||||
"doc_3": 4,
|
||||
"doc_4": 5,
|
||||
"doc_5": 6,
|
||||
}
|
||||
|
||||
mock_doc_mapping_rerank = {
|
||||
"doc_0": 2,
|
||||
"doc_1": 1,
|
||||
"doc_2": 4,
|
||||
"doc_3": 3,
|
||||
"doc_4": 6,
|
||||
"doc_5": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
|
||||
return mock_docs, mock_doc_mapping
|
||||
|
||||
|
||||
def process_text(
|
||||
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
mock_docs, mock_doc_id_to_rank_map = mock_data
|
||||
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
||||
processor = CitationProcessor(
|
||||
context_docs=mock_docs,
|
||||
doc_id_to_rank_map=mapping,
|
||||
display_doc_order_dict=mock_doc_mapping_rerank,
|
||||
stop_stream=None,
|
||||
)
|
||||
|
||||
result: list[DanswerAnswerPiece | CitationInfo] = []
|
||||
for token in tokens:
|
||||
result.extend(processor.process_token(token))
|
||||
result.extend(processor.process_token(None))
|
||||
|
||||
final_answer_text = ""
|
||||
citations = []
|
||||
for piece in result:
|
||||
if isinstance(piece, DanswerAnswerPiece):
|
||||
final_answer_text += piece.answer_piece or ""
|
||||
elif isinstance(piece, CitationInfo):
|
||||
citations.append(piece)
|
||||
|
||||
return final_answer_text, citations
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_name, input_tokens, expected_text, expected_citations",
|
||||
[
|
||||
(
|
||||
"Single citation",
|
||||
["Gro", "wth! [", "1", "]", "."],
|
||||
"Growth! [[2]](https://0.com).",
|
||||
["doc_0"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_citation_substitution(
|
||||
mock_data: tuple[list[LlmDoc], dict[str, int]],
|
||||
test_name: str,
|
||||
input_tokens: list[str],
|
||||
expected_text: str,
|
||||
expected_citations: list[str],
|
||||
) -> None:
|
||||
final_answer_text, citations = process_text(input_tokens, mock_data)
|
||||
assert (
|
||||
final_answer_text.strip() == expected_text.strip()
|
||||
), f"Test '{test_name}' failed: Final answer text does not match expected output."
|
||||
assert [
|
||||
citation.document_id for citation in citations
|
||||
] == expected_citations, (
|
||||
f"Test '{test_name}' failed: Citations do not match expected output."
|
||||
)
|
Reference in New Issue
Block a user