Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-19 20:24:32 +02:00)
Mismatch issue of Documents shown and Citation number in text fix (#3421)

* Mismatch issue of Documents shown and Citation number in text fix

  When the document order presented to the LLM differs from the order shown to the user, the wrong doc numbers are cited.

  Fix:
  - SearchTool.get_search_result now returns both the final and the initial ranking
  - the initial ranking is passed through a few objects and used for substitution during citation processing

  Notes:
  - the citation_num in the CitationInfo() object has not been changed.

* PR fixes

  - linting
  - removed erroneous tab
  - added a substitution test case
  - adjusted original citation extraction use case

* Included a key test

* Fixed extra spaces

* Updated test documentation

  Updated:
  - test_citation_substitution (changed description)
  - test_citation_processing (removed data only relevant for the substitution)
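For orientation, here is a minimal, self-contained sketch of the substitution idea (illustrative names and data, not the code from this commit): the processor keeps the LLM's own citation numbering internally but rewrites the bracketed number in the streamed text using a document_id-to-displayed-rank map, falling back to the LLM's number when a document is missing from the map.

import re

# Hypothetical example: the LLM saw the documents in re-ranked order
# [doc_b, doc_a], so "[1]" in its answer refers to doc_b, while the UI
# listed doc_b to the user as source #2.
llm_order_docs = ["doc_b", "doc_a"]                # order presented to the LLM
display_doc_order_dict = {"doc_a": 1, "doc_b": 2}  # order shown to the user


def substitute_citations(text: str) -> str:
    """Rewrite [n] (LLM numbering) as [[m]] (the number the user actually saw)."""

    def repl(match: re.Match) -> str:
        llm_num = int(match.group(1))
        doc_id = llm_order_docs[llm_num - 1]
        # fall back to the LLM's own number if the doc is unknown, mirroring
        # the warning-and-fallback path added to CitationProcessor
        displayed_num = display_doc_order_dict.get(doc_id, llm_num)
        return f"[[{displayed_num}]]"

    return re.sub(r"\[(\d+)\]", repl, text)


print(substitute_citations("Growth happens in winter [1]."))  # -> "... [[2]]."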
@@ -206,7 +206,9 @@ class Answer:
             # + figure out what the next LLM call should be
             tool_call_handler = ToolResponseHandler(current_llm_call.tools)
 
-            search_result = SearchTool.get_search_result(current_llm_call) or []
+            search_result, displayed_search_results_map = SearchTool.get_search_result(
+                current_llm_call
+            ) or ([], {})
 
             # Quotes are no longer supported
             # answer_handler: AnswerResponseHandler
@@ -224,6 +226,7 @@ class Answer:
             answer_handler = CitationResponseHandler(
                 context_docs=search_result,
                 doc_id_to_rank_map=map_document_id_order(search_result),
+                display_doc_order_dict=displayed_search_results_map,
             )
 
             response_handler_manager = LLMResponseHandlerManager(
@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
 
 class CitationResponseHandler(AnswerResponseHandler):
     def __init__(
-        self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
+        self,
+        context_docs: list[LlmDoc],
+        doc_id_to_rank_map: DocumentIdOrderMapping,
+        display_doc_order_dict: dict[str, int],
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
+        self.display_doc_order_dict = display_doc_order_dict
         self.citation_processor = CitationProcessor(
             context_docs=self.context_docs,
             doc_id_to_rank_map=self.doc_id_to_rank_map,
+            display_doc_order_dict=self.display_doc_order_dict,
         )
         self.processed_text = ""
         self.citations: list[CitationInfo] = []
@@ -22,12 +22,16 @@ class CitationProcessor:
         self,
         context_docs: list[LlmDoc],
         doc_id_to_rank_map: DocumentIdOrderMapping,
+        display_doc_order_dict: dict[str, int],
         stop_stream: str | None = STOP_STREAM_PAT,
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
         self.stop_stream = stop_stream
         self.order_mapping = doc_id_to_rank_map.order_mapping
+        self.display_doc_order_dict = (
+            display_doc_order_dict  # original order of docs as displayed to user
+        )
         self.llm_out = ""
         self.max_citation_num = len(context_docs)
         self.citation_order: list[int] = []
@@ -98,6 +102,18 @@ class CitationProcessor:
                    self.citation_order.index(real_citation_num) + 1
                )

+                # get the value that was displayed to the user; it should always
+                # be in the display_doc_order_dict, but check anyway
+                if context_llm_doc.document_id in self.display_doc_order_dict:
+                    displayed_citation_num = self.display_doc_order_dict[
+                        context_llm_doc.document_id
+                    ]
+                else:
+                    displayed_citation_num = real_citation_num
+                    logger.warning(
+                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
+                    )
+
                # Skip consecutive citations of the same work
                if target_citation_num in self.current_citations:
                    start, end = citation.span()
@@ -118,6 +134,7 @@ class CitationProcessor:
                doc_id = int(match.group(1))
                context_llm_doc = self.context_docs[doc_id - 1]
                yield CitationInfo(
+                    # stay with the original for now (order of LLM cites)
                    citation_num=target_citation_num,
                    document_id=context_llm_doc.document_id,
                )
@@ -139,6 +156,7 @@ class CitationProcessor:
                if target_citation_num not in self.cited_inds:
                    self.cited_inds.add(target_citation_num)
                    yield CitationInfo(
+                        # stay with the original for now (order of LLM cites)
                        citation_num=target_citation_num,
                        document_id=context_llm_doc.document_id,
                    )
@@ -148,7 +166,8 @@ class CitationProcessor:
            prev_length = len(self.curr_segment)
            self.curr_segment = (
                self.curr_segment[: start + length_to_add]
-                + f"[[{target_citation_num}]]({link})"
+                + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
+                # + f"[[{target_citation_num}]]({link})"
                + self.curr_segment[end + length_to_add :]
            )
            length_to_add += len(self.curr_segment) - prev_length
@@ -156,7 +175,8 @@ class CitationProcessor:
            prev_length = len(self.curr_segment)
            self.curr_segment = (
                self.curr_segment[: start + length_to_add]
-                + f"[[{target_citation_num}]]()"
+                + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
+                # + f"[[{target_citation_num}]]()"
                + self.curr_segment[end + length_to_add :]
            )
            length_to_add += len(self.curr_segment) - prev_length
@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
 from danswer.tools.tool_implementations.search_like_tool_utils import (
     FINAL_CONTEXT_DOCUMENTS_ID,
 )
+from danswer.tools.tool_implementations.search_like_tool_utils import (
+    ORIGINAL_CONTEXT_DOCUMENTS_ID,
+)
 from danswer.utils.logger import setup_logger
 from danswer.utils.special_types import JSON_ro
 
@@ -391,15 +394,35 @@ class SearchTool(Tool):
     """Other utility functions"""
 
     @classmethod
-    def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
+    def get_search_result(
+        cls, llm_call: LLMCall
+    ) -> tuple[list[LlmDoc], dict[str, int]] | None:
+        """
+        Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
+        """
         if not llm_call.tool_call_info:
             return None
 
+        final_search_results = []
+        doc_id_to_original_search_rank_map = {}
+
         for yield_item in llm_call.tool_call_info:
             if (
                 isinstance(yield_item, ToolResponse)
                 and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
             ):
-                return cast(list[LlmDoc], yield_item.response)
+                final_search_results = cast(list[LlmDoc], yield_item.response)
+            elif (
+                isinstance(yield_item, ToolResponse)
+                and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
+            ):
+                search_contexts = yield_item.response.contexts
+                original_doc_search_rank = 1
+                for idx, doc in enumerate(search_contexts):
+                    if doc.document_id not in doc_id_to_original_search_rank_map:
+                        doc_id_to_original_search_rank_map[
+                            doc.document_id
+                        ] = original_doc_search_rank
+                        original_doc_search_rank += 1
 
-        return None
+        return final_search_results, doc_id_to_original_search_rank_map
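The loop above gives each document the 1-based position at which it first appeared in the original (pre-reranking) results, so duplicate chunks of the same document do not consume extra ranks. A standalone sketch of that behaviour, using made-up document IDs (not code from this commit):

def build_display_rank_map(original_doc_ids: list[str]) -> dict[str, int]:
    """Map each document_id to the 1-based rank it was first shown at.

    Duplicate entries (e.g. several chunks of the same document) keep the
    rank of their first occurrence, mirroring the loop in get_search_result.
    """
    rank_map: dict[str, int] = {}
    next_rank = 1
    for doc_id in original_doc_ids:
        if doc_id not in rank_map:
            rank_map[doc_id] = next_rank
            next_rank += 1
    return rank_map


# two chunks of doc_a were retrieved; doc_a still only occupies rank 1
print(build_display_rank_map(["doc_a", "doc_a", "doc_b", "doc_c"]))
# -> {'doc_a': 1, 'doc_b': 2, 'doc_c': 3}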
@@ -15,6 +15,7 @@ from danswer.tools.message import ToolCallSummary
 from danswer.tools.models import ToolResponse
 
 
+ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
 FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
 
 
@@ -72,8 +72,10 @@ def process_text(
     processor = CitationProcessor(
         context_docs=mock_docs,
         doc_id_to_rank_map=mapping,
+        display_doc_order_dict=mock_doc_id_to_rank_map,
         stop_stream=None,
     )
 
     result: list[DanswerAnswerPiece | CitationInfo] = []
     for token in tokens:
         result.extend(processor.process_token(token))
@@ -86,6 +88,7 @@ def process_text(
             final_answer_text += piece.answer_piece or ""
         elif isinstance(piece, CitationInfo):
             citations.append(piece)
+
     return final_answer_text, citations
 
 
@@ -0,0 +1,132 @@
+from datetime import datetime
+
+import pytest
+
+from danswer.chat.models import CitationInfo
+from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import LlmDoc
+from danswer.chat.stream_processing.citation_processing import CitationProcessor
+from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
+from danswer.configs.constants import DocumentSource
+
+
+"""
+This module contains tests for the citation extraction functionality in Danswer,
+specifically the substitution of the document number cited in the UI. (The LLM
+sees the sources after re-ranking and the relevance check; the UI sees them before
+these steps.) This module is a derivative of test_citation_processing.py.
+
+The tests focus specifically on the substitution of the document number cited in the UI.
+
+Key components:
+- mock_docs: A list of mock LlmDoc objects used for testing.
+- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
+- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
+- process_text: A helper function that simulates the citation extraction process.
+- test_citation_substitution: A parametrized test function covering various citation scenarios.
+
+To add new test cases:
+1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_substitution.
+2. Each tuple should contain:
+   - A descriptive test name (string)
+   - Input tokens (list of strings)
+   - Expected output text (string)
+   - Expected citations (list of document IDs)
+"""
+
+
+mock_docs = [
+    LlmDoc(
+        document_id=f"doc_{int(id/2)}",
+        content="Document is a doc",
+        blurb=f"Document #{id}",
+        semantic_identifier=f"Doc {id}",
+        source_type=DocumentSource.WEB,
+        metadata={},
+        updated_at=datetime.now(),
+        link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
+        source_links={0: "https://mintlify.com/docs/settings/broken-links"},
+        match_highlights=[],
+    )
+    for id in range(10)
+]
+
+mock_doc_mapping = {
+    "doc_0": 1,
+    "doc_1": 2,
+    "doc_2": 3,
+    "doc_3": 4,
+    "doc_4": 5,
+    "doc_5": 6,
+}
+
+mock_doc_mapping_rerank = {
+    "doc_0": 2,
+    "doc_1": 1,
+    "doc_2": 4,
+    "doc_3": 3,
+    "doc_4": 6,
+    "doc_5": 5,
+}
+
+
+@pytest.fixture
+def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
+    return mock_docs, mock_doc_mapping
+
+
+def process_text(
+    tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
+) -> tuple[str, list[CitationInfo]]:
+    mock_docs, mock_doc_id_to_rank_map = mock_data
+    mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
+    processor = CitationProcessor(
+        context_docs=mock_docs,
+        doc_id_to_rank_map=mapping,
+        display_doc_order_dict=mock_doc_mapping_rerank,
+        stop_stream=None,
+    )
+
+    result: list[DanswerAnswerPiece | CitationInfo] = []
+    for token in tokens:
+        result.extend(processor.process_token(token))
+    result.extend(processor.process_token(None))
+
+    final_answer_text = ""
+    citations = []
+    for piece in result:
+        if isinstance(piece, DanswerAnswerPiece):
+            final_answer_text += piece.answer_piece or ""
+        elif isinstance(piece, CitationInfo):
+            citations.append(piece)
+
+    return final_answer_text, citations
+
+
+@pytest.mark.parametrize(
+    "test_name, input_tokens, expected_text, expected_citations",
+    [
+        (
+            "Single citation",
+            ["Gro", "wth! [", "1", "]", "."],
+            "Growth! [[2]](https://0.com).",
+            ["doc_0"],
+        ),
+    ],
+)
+def test_citation_substitution(
+    mock_data: tuple[list[LlmDoc], dict[str, int]],
+    test_name: str,
+    input_tokens: list[str],
+    expected_text: str,
+    expected_citations: list[str],
+) -> None:
+    final_answer_text, citations = process_text(input_tokens, mock_data)
+    assert (
+        final_answer_text.strip() == expected_text.strip()
+    ), f"Test '{test_name}' failed: Final answer text does not match expected output."
+    assert [
+        citation.document_id for citation in citations
+    ] == expected_citations, (
+        f"Test '{test_name}' failed: Citations do not match expected output."
+    )