alignment & renaming of objects for initial (displayed) ranking and re-ranking/validation citations

- renamed post-reranking/validation citation information consistently to final_... (example: doc_id_to_rank_map -> final_doc_id_to_rank_map)
- changed and renamed the objects containing initial ranking information (now: display_...) so that they are consistent with the final-ranking objects (final_...). Specifically, displayed_search_results changed from a dict ({}) to a list ([])
- for CitationInfo, changed citation_num from 'x-th citation in response stream' to the initial (displayed) rank of the doc [NOTE: test implications]
- changed tests:
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py
This commit is contained in:
joachim-danswer
2024-12-18 14:00:42 -08:00
committed by Chris Weaver
parent 27699c8216
commit 8750f14647
6 changed files with 64 additions and 60 deletions

View File

@@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.stream_processing.answer_response_handler import ( from onyx.chat.stream_processing.answer_response_handler import (
DummyAnswerResponseHandler, DummyAnswerResponseHandler,
) )
from onyx.chat.stream_processing.utils import map_document_id_order from onyx.chat.stream_processing.utils import (
map_document_id_order,
)
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.file_store.utils import InMemoryChatFile from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLM
@@ -206,9 +208,9 @@ class Answer:
# + figure out what the next LLM call should be # + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools) tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result, displayed_search_results_map = SearchTool.get_search_result( final_search_results, displayed_search_results = SearchTool.get_search_result(
current_llm_call current_llm_call
) or ([], {}) ) or ([], [])
# Quotes are no longer supported # Quotes are no longer supported
# answer_handler: AnswerResponseHandler # answer_handler: AnswerResponseHandler
@@ -224,9 +226,9 @@ class Answer:
# else: # else:
# raise ValueError("No answer style config provided") # raise ValueError("No answer style config provided")
answer_handler = CitationResponseHandler( answer_handler = CitationResponseHandler(
context_docs=search_result, context_docs=final_search_results,
doc_id_to_rank_map=map_document_id_order(search_result), final_doc_id_to_rank_map=map_document_id_order(final_search_results),
display_doc_order_dict=displayed_search_results_map, display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
) )
response_handler_manager = LLMResponseHandlerManager( response_handler_manager = LLMResponseHandlerManager(

View File

@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
def __init__( def __init__(
self, self,
context_docs: list[LlmDoc], context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping, final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int], display_doc_id_to_rank_map: DocumentIdOrderMapping,
): ):
self.context_docs = context_docs self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.citation_processor = CitationProcessor( self.citation_processor = CitationProcessor(
context_docs=self.context_docs, context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map, final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict, display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
) )
self.processed_text = "" self.processed_text = ""
self.citations: list[CitationInfo] = [] self.citations: list[CitationInfo] = []
# TODO remove this after citation issue is resolved # TODO remove this after citation issue is resolved
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}") logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
def handle_response_part( def handle_response_part(
self, self,

View File

@@ -21,20 +21,19 @@ class CitationProcessor:
def __init__( def __init__(
self, self,
context_docs: list[LlmDoc], context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping, final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int], display_doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT, stop_stream: str | None = STOP_STREAM_PAT,
): ):
self.context_docs = context_docs self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.stop_stream = stop_stream self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = ( self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
display_doc_order_dict # original order of docs to displayed to user
)
self.llm_out = "" self.llm_out = ""
self.max_citation_num = len(context_docs) self.max_citation_num = len(context_docs)
self.citation_order: list[int] = [] self.citation_order: list[int] = [] # order of citations in the LLM output
self.curr_segment = "" self.curr_segment = ""
self.cited_inds: set[int] = set() self.cited_inds: set[int] = set()
self.hold = "" self.hold = ""
@@ -93,29 +92,31 @@ class CitationProcessor:
if 1 <= numerical_value <= self.max_citation_num: if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1] context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id] final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if real_citation_num not in self.citation_order: if final_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num) self.citation_order.append(final_citation_num)
target_citation_num = ( citation_order_idx = (
self.citation_order.index(real_citation_num) + 1 self.citation_order.index(final_citation_num) + 1
) )
# get the value that was displayed to user, should always # get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways # be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict: if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_doc_order_dict[ displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id context_llm_doc.document_id
] ]
else: else:
displayed_citation_num = real_citation_num displayed_citation_num = final_citation_num
logger.warning( logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead." f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
) )
# Skip consecutive citations of the same work # Skip consecutive citations of the same work
if target_citation_num in self.current_citations: if final_citation_num in self.current_citations:
start, end = citation.span() start, end = citation.span()
real_start = length_to_add + start real_start = length_to_add + start
diff = end - start diff = end - start
@@ -134,8 +135,8 @@ class CitationProcessor:
doc_id = int(match.group(1)) doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1] context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo( yield CitationInfo(
# stay with the original for now (order of LLM cites) # citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=target_citation_num, citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id, document_id=context_llm_doc.document_id,
) )
except Exception as e: except Exception as e:
@@ -151,13 +152,13 @@ class CitationProcessor:
link = context_llm_doc.link link = context_llm_doc.link
self.past_cite_count = len(self.llm_out) self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num) self.current_citations.append(final_citation_num)
if target_citation_num not in self.cited_inds: if citation_order_idx not in self.cited_inds:
self.cited_inds.add(target_citation_num) self.cited_inds.add(citation_order_idx)
yield CitationInfo( yield CitationInfo(
# stay with the original for now (order of LLM cites) # citation number is now the one that was displayed to user
citation_num=target_citation_num, citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id, document_id=context_llm_doc.document_id,
) )
@@ -167,7 +168,6 @@ class CitationProcessor:
self.curr_segment = ( self.curr_segment = (
self.curr_segment[: start + length_to_add] self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user + f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :] + self.curr_segment[end + length_to_add :]
) )
length_to_add += len(self.curr_segment) - prev_length length_to_add += len(self.curr_segment) - prev_length
@@ -176,7 +176,6 @@ class CitationProcessor:
self.curr_segment = ( self.curr_segment = (
self.curr_segment[: start + length_to_add] self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user + f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :] + self.curr_segment[end + length_to_add :]
) )
length_to_add += len(self.curr_segment) - prev_length length_to_add += len(self.curr_segment) - prev_length

View File

@@ -396,7 +396,7 @@ class SearchTool(Tool):
@classmethod @classmethod
def get_search_result( def get_search_result(
cls, llm_call: LLMCall cls, llm_call: LLMCall
) -> tuple[list[LlmDoc], dict[str, int]] | None: ) -> tuple[list[LlmDoc], list[LlmDoc]] | None:
""" """
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user) Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
""" """
@@ -404,7 +404,7 @@ class SearchTool(Tool):
return None return None
final_search_results = [] final_search_results = []
doc_id_to_original_search_rank_map = {} initial_search_results = []
for yield_item in llm_call.tool_call_info: for yield_item in llm_call.tool_call_info:
if ( if (
@@ -417,12 +417,11 @@ class SearchTool(Tool):
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
): ):
search_contexts = yield_item.response.contexts search_contexts = yield_item.response.contexts
original_doc_search_rank = 1 # original_doc_search_rank = 1
for idx, doc in enumerate(search_contexts): for doc in search_contexts:
if doc.document_id not in doc_id_to_original_search_rank_map: if doc.document_id not in initial_search_results:
doc_id_to_original_search_rank_map[ initial_search_results.append(doc)
doc.document_id
] = original_doc_search_rank
original_doc_search_rank += 1
return final_search_results, doc_id_to_original_search_rank_map initial_search_results = cast(list[LlmDoc], initial_search_results)
return final_search_results, initial_search_results

View File

@@ -68,11 +68,12 @@ def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]: ) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
display_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
processor = CitationProcessor( processor = CitationProcessor(
context_docs=mock_docs, context_docs=mock_docs,
doc_id_to_rank_map=mapping, final_doc_id_to_rank_map=final_mapping,
display_doc_order_dict=mock_doc_id_to_rank_map, display_doc_id_to_rank_map=display_mapping,
stop_stream=None, stop_stream=None,
) )

View File

@@ -71,19 +71,22 @@ mock_doc_mapping_rerank = {
@pytest.fixture @pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]: def mock_data() -> tuple[list[LlmDoc], dict[str, int], dict[str, int]]:
return mock_docs, mock_doc_mapping return mock_docs, mock_doc_mapping, mock_doc_mapping_rerank
def process_text( def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]] tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]]
) -> tuple[str, list[CitationInfo]]: ) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data mock_docs, mock_doc_id_to_rank_map, mock_doc_id_to_rank_map_rerank = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map) final_mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
display_mapping = DocumentIdOrderMapping(
order_mapping=mock_doc_id_to_rank_map_rerank
)
processor = CitationProcessor( processor = CitationProcessor(
context_docs=mock_docs, context_docs=mock_docs,
doc_id_to_rank_map=mapping, final_doc_id_to_rank_map=final_mapping,
display_doc_order_dict=mock_doc_mapping_rerank, display_doc_id_to_rank_map=display_mapping,
stop_stream=None, stop_stream=None,
) )
@@ -115,7 +118,7 @@ def process_text(
], ],
) )
def test_citation_substitution( def test_citation_substitution(
mock_data: tuple[list[LlmDoc], dict[str, int]], mock_data: tuple[list[LlmDoc], dict[str, int], dict[str, int]],
test_name: str, test_name: str,
input_tokens: list[str], input_tokens: list[str],
expected_text: str, expected_text: str,