Fix mismatch between documents shown and citation numbers in text (#3421)

* Fix mismatch between documents shown and citation numbers in text

When the document order presented to the LLM differs from the order shown to the user, the wrong document numbers are cited in the answer.

Fix:
 - SearchTool.get_search_result now returns both the final ranking and the initial ranking
 - the initial ranking is passed through a few objects and used to substitute the displayed citation numbers during citation processing (a short sketch follows the notes below)

Notes:
 - the citation_num in the CitationInfo() object is unchanged; it still reflects the LLM citation order.
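
A minimal sketch of the idea, with illustrative names rather than the actual Danswer objects: given the docs in the order the LLM saw them and a map from document id to the rank shown in the UI, a streamed citation like "[1]" is rewritten to the number the user actually saw.

import re

def substitute_displayed_citation(
    text: str, llm_order_doc_ids: list[str], display_rank: dict[str, int]
) -> str:
    # "[n]" refers to the n-th document in the order the LLM was given
    def repl(match: re.Match) -> str:
        llm_num = int(match.group(1))
        doc_id = llm_order_doc_ids[llm_num - 1]
        # fall back to the LLM's own number if the doc was never displayed
        return f"[{display_rank.get(doc_id, llm_num)}]"

    return re.sub(r"\[(\d+)\]", repl, text)

# The LLM saw doc_b first, but the UI listed doc_a first, so "[1]" becomes "[2]"
print(substitute_displayed_citation("See [1].", ["doc_b", "doc_a"], {"doc_a": 1, "doc_b": 2}))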

* PR fixes

 - linting
 - removed erroneous tab
 - added a substitution test case
 - adjusted original citation extraction use case

* Included a key test

* Fixed extra spaces

* Updated test documentation

Updated:
 - test_citation_substitution (changed description)
 - test_citation_processing (removed data only relevant to the substitution test)
joachim-danswer
2024-12-11 11:58:24 -08:00
committed by GitHub
parent 71421bb782
commit 9455576078
7 changed files with 194 additions and 7 deletions

View File

@@ -206,7 +206,9 @@ class Answer:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
- search_result = SearchTool.get_search_result(current_llm_call) or []
+ search_result, displayed_search_results_map = SearchTool.get_search_result(
+ current_llm_call
+ ) or ([], {})
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
@@ -224,6 +226,7 @@ class Answer:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
)
response_handler_manager = LLMResponseHandlerManager(

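A tiny stand-alone illustration of the call-site pattern above, with a made-up stub in place of SearchTool: the result may be None, so the call site unpacks it with empty defaults.

def get_search_result_stub(has_results: bool):
    # returns a (docs, display_rank_map) tuple, or None when the search tool never ran
    return (["doc_a"], {"doc_a": 1}) if has_results else None

search_result, displayed_search_results_map = get_search_result_stub(False) or ([], {})
print(search_result, displayed_search_results_map)  # [] {}
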
View File

@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
- self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
+ self,
+ context_docs: list[LlmDoc],
+ doc_id_to_rank_map: DocumentIdOrderMapping,
+ display_doc_order_dict: dict[str, int],
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []

View File

@@ -22,12 +22,16 @@ class CitationProcessor:
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict  # original order of docs as displayed to the user
)
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
@@ -98,6 +102,18 @@ class CitationProcessor:
self.citation_order.index(real_citation_num) + 1
)
# get the value that was displayed to the user; it should always
# be in the display_doc_order_dict, but check anyway
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
start, end = citation.span()
@@ -118,6 +134,7 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -139,6 +156,7 @@ class CitationProcessor:
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -148,7 +166,8 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@@ -156,7 +175,8 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
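
Pulling the new fallback logic out into a stand-alone helper may make the behaviour easier to see; this is an illustrative sketch, not the actual class. The number written into the streamed markdown is the displayed one, falling back to the LLM's own citation number (with a warning) when the document id is missing from the map. Per the comments in the diff, CitationInfo still carries the LLM-order citation_num; only the streamed text changes.

import logging

logger = logging.getLogger(__name__)

def resolve_displayed_citation_num(
    document_id: str, real_citation_num: int, display_doc_order_dict: dict[str, int]
) -> int:
    # prefer the number the user actually saw in the UI
    if document_id in display_doc_order_dict:
        return display_doc_order_dict[document_id]
    # otherwise keep the LLM's number and warn, mirroring the diff above
    logger.warning(
        f"Doc {document_id} not in display_doc_order_dict. Used LLM citation number instead."
    )
    return real_citation_num

link = "https://0.com"
displayed = resolve_displayed_citation_num("doc_0", 1, {"doc_0": 2})
print(f"[[{displayed}]]({link})")  # [[2]](https://0.com) is what gets streamed to the user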

View File

@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
ORIGINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
@@ -391,15 +394,35 @@ class SearchTool(Tool):
"""Other utility functions"""
@classmethod
- def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
+ def get_search_result(
+ cls, llm_call: LLMCall
+ ) -> tuple[list[LlmDoc], dict[str, int]] | None:
"""
Returns the final search results and a map of docs to their original search rank (which is what is displayed to the user)
"""
if not llm_call.tool_call_info:
return None
final_search_results = []
doc_id_to_original_search_rank_map = {}
for yield_item in llm_call.tool_call_info:
if (
isinstance(yield_item, ToolResponse)
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
):
- return cast(list[LlmDoc], yield_item.response)
+ final_search_results = cast(list[LlmDoc], yield_item.response)
elif (
isinstance(yield_item, ToolResponse)
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
):
search_contexts = yield_item.response.contexts
original_doc_search_rank = 1
for idx, doc in enumerate(search_contexts):
if doc.document_id not in doc_id_to_original_search_rank_map:
doc_id_to_original_search_rank_map[
doc.document_id
] = original_doc_search_rank
original_doc_search_rank += 1
- return None
+ return final_search_results, doc_id_to_original_search_rank_map

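The rank map built above records the first rank at which each document id appears, so multiple chunks from the same document collapse to one displayed number. A small stand-alone sketch, with plain strings standing in for the search context objects and assuming (from the diff) that the counter only advances when a new document id is seen:

search_context_doc_ids = ["doc_a", "doc_b", "doc_a", "doc_c"]  # chunk order as displayed
doc_id_to_original_search_rank_map: dict[str, int] = {}
original_doc_search_rank = 1
for doc_id in search_context_doc_ids:
    if doc_id not in doc_id_to_original_search_rank_map:
        doc_id_to_original_search_rank_map[doc_id] = original_doc_search_rank
        original_doc_search_rank += 1
print(doc_id_to_original_search_rank_map)  # {'doc_a': 1, 'doc_b': 2, 'doc_c': 3}
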
View File

@@ -15,6 +15,7 @@ from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"

View File

@@ -72,8 +72,10 @@ def process_text(
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_id_to_rank_map,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
@@ -86,6 +88,7 @@ def process_text(
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations

View File

@@ -0,0 +1,132 @@
from datetime import datetime
import pytest
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.constants import DocumentSource
"""
This module contains tests for the citation extraction functionality in Danswer,
specifically the substitution of the document numbers cited in the UI. (The LLM
sees the sources after re-ranking and the relevance check, while the UI shows them
before these steps.)
This module is a derivative of test_citation_processing.py and focuses specifically on that substitution.
Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_substitution: A parametrized test function covering various citation substitution scenarios.
To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_substitution.
2. Each tuple should contain:
- A descriptive test name (string)
- Input tokens (list of strings)
- Expected output text (string)
- Expected citations (list of document IDs)
"""
mock_docs = [
LlmDoc(
document_id=f"doc_{int(id/2)}",
content="Document is a doc",
blurb=f"Document #{id}",
semantic_identifier=f"Doc {id}",
source_type=DocumentSource.WEB,
metadata={},
updated_at=datetime.now(),
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
match_highlights=[],
)
for id in range(10)
]
mock_doc_mapping = {
"doc_0": 1,
"doc_1": 2,
"doc_2": 3,
"doc_3": 4,
"doc_4": 5,
"doc_5": 6,
}
mock_doc_mapping_rerank = {
"doc_0": 2,
"doc_1": 1,
"doc_2": 4,
"doc_3": 3,
"doc_4": 6,
"doc_5": 5,
}
@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
return mock_docs, mock_doc_mapping
def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_mapping_rerank,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
result.extend(processor.process_token(None))
final_answer_text = ""
citations = []
for piece in result:
if isinstance(piece, DanswerAnswerPiece):
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations
@pytest.mark.parametrize(
"test_name, input_tokens, expected_text, expected_citations",
[
(
"Single citation",
["Gro", "wth! [", "1", "]", "."],
"Growth! [[2]](https://0.com).",
["doc_0"],
),
],
)
def test_citation_substitution(
mock_data: tuple[list[LlmDoc], dict[str, int]],
test_name: str,
input_tokens: list[str],
expected_text: str,
expected_citations: list[str],
) -> None:
final_answer_text, citations = process_text(input_tokens, mock_data)
assert (
final_answer_text.strip() == expected_text.strip()
), f"Test '{test_name}' failed: Final answer text does not match expected output."
assert [
citation.document_id for citation in citations
] == expected_citations, (
f"Test '{test_name}' failed: Citations do not match expected output."
)
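
For readers following along, a hand trace of the single parametrized case, run against the module-level mock data above (illustrative, not part of the committed test file):

doc = mock_docs[1 - 1]                                # "[1]" points at the first context doc, doc_0
displayed = mock_doc_mapping_rerank[doc.document_id]  # doc_0 was displayed to the user as 2
print(f"Growth! [[{displayed}]]({doc.link}).")        # Growth! [[2]](https://0.com).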