mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-19 12:03:54 +02:00
alignment & renaming of objects for initial (displayed) ranking and re-ranking/validation citations
- renamed post-reranking/validation citation information consistently to final_... (example: doc_id_to_rank_map -> final_doc_id_to_rank_map) - changed and renamed objects containing initial ranking information (now: display_...) consistent with final rankings (final_...). Specifically, {} to [] for displayed_search_results - for CitationInfo, changed citation_num from 'x-th citation in response stream' to the initial position of the doc [NOTE: test implications] - changed tests: onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py
This commit is contained in:
committed by
Chris Weaver
parent
27699c8216
commit
8750f14647
@@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
|||||||
from onyx.chat.stream_processing.answer_response_handler import (
|
from onyx.chat.stream_processing.answer_response_handler import (
|
||||||
DummyAnswerResponseHandler,
|
DummyAnswerResponseHandler,
|
||||||
)
|
)
|
||||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
from onyx.chat.stream_processing.utils import (
|
||||||
|
map_document_id_order,
|
||||||
|
)
|
||||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||||
from onyx.file_store.utils import InMemoryChatFile
|
from onyx.file_store.utils import InMemoryChatFile
|
||||||
from onyx.llm.interfaces import LLM
|
from onyx.llm.interfaces import LLM
|
||||||
@@ -206,9 +208,9 @@ class Answer:
|
|||||||
# + figure out what the next LLM call should be
|
# + figure out what the next LLM call should be
|
||||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||||
|
|
||||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||||
current_llm_call
|
current_llm_call
|
||||||
) or ([], {})
|
) or ([], [])
|
||||||
|
|
||||||
# Quotes are no longer supported
|
# Quotes are no longer supported
|
||||||
# answer_handler: AnswerResponseHandler
|
# answer_handler: AnswerResponseHandler
|
||||||
@@ -224,9 +226,9 @@ class Answer:
|
|||||||
# else:
|
# else:
|
||||||
# raise ValueError("No answer style config provided")
|
# raise ValueError("No answer style config provided")
|
||||||
answer_handler = CitationResponseHandler(
|
answer_handler = CitationResponseHandler(
|
||||||
context_docs=search_result,
|
context_docs=final_search_results,
|
||||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||||
display_doc_order_dict=displayed_search_results_map,
|
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||||
)
|
)
|
||||||
|
|
||||||
response_handler_manager = LLMResponseHandlerManager(
|
response_handler_manager = LLMResponseHandlerManager(
|
||||||
|
@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
context_docs: list[LlmDoc],
|
context_docs: list[LlmDoc],
|
||||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||||
display_doc_order_dict: dict[str, int],
|
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||||
):
|
):
|
||||||
self.context_docs = context_docs
|
self.context_docs = context_docs
|
||||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||||
self.display_doc_order_dict = display_doc_order_dict
|
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||||
self.citation_processor = CitationProcessor(
|
self.citation_processor = CitationProcessor(
|
||||||
context_docs=self.context_docs,
|
context_docs=self.context_docs,
|
||||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
|
||||||
display_doc_order_dict=self.display_doc_order_dict,
|
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
|
||||||
)
|
)
|
||||||
self.processed_text = ""
|
self.processed_text = ""
|
||||||
self.citations: list[CitationInfo] = []
|
self.citations: list[CitationInfo] = []
|
||||||
|
|
||||||
# TODO remove this after citation issue is resolved
|
# TODO remove this after citation issue is resolved
|
||||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
|
||||||
|
|
||||||
def handle_response_part(
|
def handle_response_part(
|
||||||
self,
|
self,
|
||||||
|
@@ -21,20 +21,19 @@ class CitationProcessor:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
context_docs: list[LlmDoc],
|
context_docs: list[LlmDoc],
|
||||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||||
display_doc_order_dict: dict[str, int],
|
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||||
stop_stream: str | None = STOP_STREAM_PAT,
|
stop_stream: str | None = STOP_STREAM_PAT,
|
||||||
):
|
):
|
||||||
self.context_docs = context_docs
|
self.context_docs = context_docs
|
||||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||||
|
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||||
self.stop_stream = stop_stream
|
self.stop_stream = stop_stream
|
||||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
||||||
self.display_doc_order_dict = (
|
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
||||||
display_doc_order_dict # original order of docs to displayed to user
|
|
||||||
)
|
|
||||||
self.llm_out = ""
|
self.llm_out = ""
|
||||||
self.max_citation_num = len(context_docs)
|
self.max_citation_num = len(context_docs)
|
||||||
self.citation_order: list[int] = []
|
self.citation_order: list[int] = [] # order of citations in the LLM output
|
||||||
self.curr_segment = ""
|
self.curr_segment = ""
|
||||||
self.cited_inds: set[int] = set()
|
self.cited_inds: set[int] = set()
|
||||||
self.hold = ""
|
self.hold = ""
|
||||||
@@ -93,29 +92,31 @@ class CitationProcessor:
|
|||||||
|
|
||||||
if 1 <= numerical_value <= self.max_citation_num:
|
if 1 <= numerical_value <= self.max_citation_num:
|
||||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
final_citation_num = self.final_order_mapping[
|
||||||
|
context_llm_doc.document_id
|
||||||
|
]
|
||||||
|
|
||||||
if real_citation_num not in self.citation_order:
|
if final_citation_num not in self.citation_order:
|
||||||
self.citation_order.append(real_citation_num)
|
self.citation_order.append(final_citation_num)
|
||||||
|
|
||||||
target_citation_num = (
|
citation_order_idx = (
|
||||||
self.citation_order.index(real_citation_num) + 1
|
self.citation_order.index(final_citation_num) + 1
|
||||||
)
|
)
|
||||||
|
|
||||||
# get the value that was displayed to user, should always
|
# get the value that was displayed to user, should always
|
||||||
# be in the display_doc_order_dict. But check anyways
|
# be in the display_doc_order_dict. But check anyways
|
||||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
if context_llm_doc.document_id in self.display_order_mapping:
|
||||||
displayed_citation_num = self.display_doc_order_dict[
|
displayed_citation_num = self.display_order_mapping[
|
||||||
context_llm_doc.document_id
|
context_llm_doc.document_id
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
displayed_citation_num = real_citation_num
|
displayed_citation_num = final_citation_num
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Skip consecutive citations of the same work
|
# Skip consecutive citations of the same work
|
||||||
if target_citation_num in self.current_citations:
|
if final_citation_num in self.current_citations:
|
||||||
start, end = citation.span()
|
start, end = citation.span()
|
||||||
real_start = length_to_add + start
|
real_start = length_to_add + start
|
||||||
diff = end - start
|
diff = end - start
|
||||||
@@ -134,8 +135,8 @@ class CitationProcessor:
|
|||||||
doc_id = int(match.group(1))
|
doc_id = int(match.group(1))
|
||||||
context_llm_doc = self.context_docs[doc_id - 1]
|
context_llm_doc = self.context_docs[doc_id - 1]
|
||||||
yield CitationInfo(
|
yield CitationInfo(
|
||||||
# stay with the original for now (order of LLM cites)
|
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||||
citation_num=target_citation_num,
|
citation_num=displayed_citation_num,
|
||||||
document_id=context_llm_doc.document_id,
|
document_id=context_llm_doc.document_id,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -151,13 +152,13 @@ class CitationProcessor:
|
|||||||
link = context_llm_doc.link
|
link = context_llm_doc.link
|
||||||
|
|
||||||
self.past_cite_count = len(self.llm_out)
|
self.past_cite_count = len(self.llm_out)
|
||||||
self.current_citations.append(target_citation_num)
|
self.current_citations.append(final_citation_num)
|
||||||
|
|
||||||
if target_citation_num not in self.cited_inds:
|
if citation_order_idx not in self.cited_inds:
|
||||||
self.cited_inds.add(target_citation_num)
|
self.cited_inds.add(citation_order_idx)
|
||||||
yield CitationInfo(
|
yield CitationInfo(
|
||||||
# stay with the original for now (order of LLM cites)
|
# citation number is now the one that was displayed to user
|
||||||
citation_num=target_citation_num,
|
citation_num=displayed_citation_num,
|
||||||
document_id=context_llm_doc.document_id,
|
document_id=context_llm_doc.document_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -167,7 +168,6 @@ class CitationProcessor:
|
|||||||
self.curr_segment = (
|
self.curr_segment = (
|
||||||
self.curr_segment[: start + length_to_add]
|
self.curr_segment[: start + length_to_add]
|
||||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||||
# + f"[[{target_citation_num}]]({link})"
|
|
||||||
+ self.curr_segment[end + length_to_add :]
|
+ self.curr_segment[end + length_to_add :]
|
||||||
)
|
)
|
||||||
length_to_add += len(self.curr_segment) - prev_length
|
length_to_add += len(self.curr_segment) - prev_length
|
||||||
@@ -176,7 +176,6 @@ class CitationProcessor:
|
|||||||
self.curr_segment = (
|
self.curr_segment = (
|
||||||
self.curr_segment[: start + length_to_add]
|
self.curr_segment[: start + length_to_add]
|
||||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||||
# + f"[[{target_citation_num}]]()"
|
|
||||||
+ self.curr_segment[end + length_to_add :]
|
+ self.curr_segment[end + length_to_add :]
|
||||||
)
|
)
|
||||||
length_to_add += len(self.curr_segment) - prev_length
|
length_to_add += len(self.curr_segment) - prev_length
|
||||||
|
@@ -396,7 +396,7 @@ class SearchTool(Tool):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def get_search_result(
|
def get_search_result(
|
||||||
cls, llm_call: LLMCall
|
cls, llm_call: LLMCall
|
||||||
) -> tuple[list[LlmDoc], dict[str, int]] | None:
|
) -> tuple[list[LlmDoc], list[LlmDoc]] | None:
|
||||||
"""
|
"""
|
||||||
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
|
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
|
||||||
"""
|
"""
|
||||||
@@ -404,7 +404,7 @@ class SearchTool(Tool):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
final_search_results = []
|
final_search_results = []
|
||||||
doc_id_to_original_search_rank_map = {}
|
initial_search_results = []
|
||||||
|
|
||||||
for yield_item in llm_call.tool_call_info:
|
for yield_item in llm_call.tool_call_info:
|
||||||
if (
|
if (
|
||||||
@@ -417,12 +417,11 @@ class SearchTool(Tool):
|
|||||||
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
|
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
|
||||||
):
|
):
|
||||||
search_contexts = yield_item.response.contexts
|
search_contexts = yield_item.response.contexts
|
||||||
original_doc_search_rank = 1
|
# original_doc_search_rank = 1
|
||||||
for idx, doc in enumerate(search_contexts):
|
for doc in search_contexts:
|
||||||
if doc.document_id not in doc_id_to_original_search_rank_map:
|
if doc.document_id not in initial_search_results:
|
||||||
doc_id_to_original_search_rank_map[
|
initial_search_results.append(doc)
|
||||||
doc.document_id
|
|
||||||
] = original_doc_search_rank
|
|
||||||
original_doc_search_rank += 1
|
|
||||||
|
|
||||||
return final_search_results, doc_id_to_original_search_rank_map
|
initial_search_results = cast(list[LlmDoc], initial_search_results)
|
||||||
|
|
||||||
|
return final_search_results, initial_search_results
|
||||||
|
@@ -68,11 +68,12 @@ def process_text(
|
|||||||
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
|
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
|
||||||
) -> tuple[str, list[CitationInfo]]:
|
) -> tuple[str, list[CitationInfo]]:
|
||||||
mock_docs, mock_doc_id_to_rank_map = mock_data
|
mock_docs, mock_doc_id_to_rank_map = mock_data
|
||||||
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
||||||
|
display_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
||||||
processor = CitationProcessor(
|
processor = CitationProcessor(
|
||||||
context_docs=mock_docs,
|
context_docs=mock_docs,
|
||||||
doc_id_to_rank_map=mapping,
|
final_doc_id_to_rank_map=final_mapping,
|
||||||
display_doc_order_dict=mock_doc_id_to_rank_map,
|
display_doc_id_to_rank_map=display_mapping,
|
||||||
stop_stream=None,
|
stop_stream=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@@ -71,19 +71,22 @@ mock_doc_mapping_rerank = {
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
|
def mock_data() -> tuple[list[LlmDoc], dict[str, int], dict[str, int]]:
|
||||||
return mock_docs, mock_doc_mapping
|
return mock_docs, mock_doc_mapping, mock_doc_mapping_rerank
|
||||||
|
|
||||||
|
|
||||||
def process_text(
|
def process_text(
|
||||||
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
|
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]]
|
||||||
) -> tuple[str, list[CitationInfo]]:
|
) -> tuple[str, list[CitationInfo]]:
|
||||||
mock_docs, mock_doc_id_to_rank_map = mock_data
|
mock_docs, mock_doc_id_to_rank_map, mock_doc_id_to_rank_map_rerank = mock_data
|
||||||
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
||||||
|
display_mapping = DocumentIdOrderMapping(
|
||||||
|
order_mapping=mock_doc_id_to_rank_map_rerank
|
||||||
|
)
|
||||||
processor = CitationProcessor(
|
processor = CitationProcessor(
|
||||||
context_docs=mock_docs,
|
context_docs=mock_docs,
|
||||||
doc_id_to_rank_map=mapping,
|
final_doc_id_to_rank_map=final_mapping,
|
||||||
display_doc_order_dict=mock_doc_mapping_rerank,
|
display_doc_id_to_rank_map=display_mapping,
|
||||||
stop_stream=None,
|
stop_stream=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -115,7 +118,7 @@ def process_text(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_citation_substitution(
|
def test_citation_substitution(
|
||||||
mock_data: tuple[list[LlmDoc], dict[str, int]],
|
mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]],
|
||||||
test_name: str,
|
test_name: str,
|
||||||
input_tokens: list[str],
|
input_tokens: list[str],
|
||||||
expected_text: str,
|
expected_text: str,
|
||||||
|
Reference in New Issue
Block a user