alignment & renaming of objects for initial (displayed) ranking and re-ranking/validation citations

- renamed post-reranking/validation citation information consistently to final_... (example: doc_id_to_rank_map -> final_doc_id_to_rank_map)
 - changed and renamed objects containing initial ranking information (now: display_...) consistent with final rankings (final_...). Specifically, {} to [] for displayed_search_results
 - for CitationInfo, changed citation_num from 'x-th citation in response stream' to the initial position of the doc [NOTE: test implications]
-  changed tests:
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_processing.py
    onyx/backend/tests/unit/onyx/chat/stream_processing/test_citation_substitution.py
This commit is contained in:
joachim-danswer
2024-12-18 14:00:42 -08:00
committed by Chris Weaver
parent 27699c8216
commit 8750f14647
6 changed files with 64 additions and 60 deletions

View File

@ -21,20 +21,19 @@ class CitationProcessor:
def __init__(
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
final_doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
self.citation_order: list[int] = [] # order of citations in the LLM output
self.curr_segment = ""
self.cited_inds: set[int] = set()
self.hold = ""
@ -93,29 +92,31 @@ class CitationProcessor:
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
real_citation_num = self.order_mapping[context_llm_doc.document_id]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if real_citation_num not in self.citation_order:
self.citation_order.append(real_citation_num)
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
target_citation_num = (
self.citation_order.index(real_citation_num) + 1
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
@ -134,8 +135,8 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
@ -151,13 +152,13 @@ class CitationProcessor:
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(target_citation_num)
self.current_citations.append(final_citation_num)
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
@ -167,7 +168,6 @@ class CitationProcessor:
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@ -176,7 +176,6 @@ class CitationProcessor:
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length