mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-11 20:30:07 +02:00
195 lines
8.4 KiB
Python
195 lines
8.4 KiB
Python
import re
|
|
from collections.abc import Generator
|
|
|
|
from onyx.chat.models import CitationInfo
|
|
from onyx.chat.models import LlmDoc
|
|
from onyx.chat.models import OnyxAnswerPiece
|
|
from onyx.chat.stream_processing.utils import DocumentIdOrderMapping
|
|
from onyx.configs.chat_configs import STOP_STREAM_PAT
|
|
from onyx.prompts.constants import TRIPLE_BACKTICK
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
def in_code_block(llm_text: str) -> bool:
|
|
count = llm_text.count(TRIPLE_BACKTICK)
|
|
return count % 2 != 0
|
|
|
|
|
|
class CitationProcessor:
|
|
def __init__(
|
|
self,
|
|
context_docs: list[LlmDoc],
|
|
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
|
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
|
stop_stream: str | None = STOP_STREAM_PAT,
|
|
):
|
|
self.context_docs = context_docs
|
|
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
|
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
|
self.stop_stream = stop_stream
|
|
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
|
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
|
self.llm_out = ""
|
|
self.max_citation_num = len(context_docs)
|
|
self.citation_order: list[int] = [] # order of citations in the LLM output
|
|
self.curr_segment = ""
|
|
self.cited_inds: set[int] = set()
|
|
self.hold = ""
|
|
self.current_citations: list[int] = []
|
|
self.past_cite_count = 0
|
|
|
|
def process_token(
|
|
self, token: str | None
|
|
) -> Generator[OnyxAnswerPiece | CitationInfo, None, None]:
|
|
# None -> end of stream
|
|
if token is None:
|
|
yield OnyxAnswerPiece(answer_piece=self.curr_segment)
|
|
return
|
|
|
|
if self.stop_stream:
|
|
next_hold = self.hold + token
|
|
if self.stop_stream in next_hold:
|
|
return
|
|
if next_hold == self.stop_stream[: len(next_hold)]:
|
|
self.hold = next_hold
|
|
return
|
|
token = next_hold
|
|
self.hold = ""
|
|
|
|
self.curr_segment += token
|
|
self.llm_out += token
|
|
|
|
# Handle code blocks without language tags
|
|
if "`" in self.curr_segment:
|
|
if self.curr_segment.endswith("`"):
|
|
pass
|
|
elif "```" in self.curr_segment:
|
|
piece_that_comes_after = self.curr_segment.split("```")[1][0]
|
|
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
|
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
|
|
|
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
|
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
|
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
|
possible_citation_found = re.search(
|
|
possible_citation_pattern, self.curr_segment
|
|
)
|
|
|
|
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
|
self.current_citations = []
|
|
|
|
result = ""
|
|
if citations_found and not in_code_block(self.llm_out):
|
|
last_citation_end = 0
|
|
length_to_add = 0
|
|
while len(citations_found) > 0:
|
|
citation = citations_found.pop(0)
|
|
numerical_value = int(
|
|
next(group for group in citation.groups() if group is not None)
|
|
)
|
|
|
|
if 1 <= numerical_value <= self.max_citation_num:
|
|
context_llm_doc = self.context_docs[numerical_value - 1]
|
|
final_citation_num = self.final_order_mapping[
|
|
context_llm_doc.document_id
|
|
]
|
|
|
|
if final_citation_num not in self.citation_order:
|
|
self.citation_order.append(final_citation_num)
|
|
|
|
citation_order_idx = (
|
|
self.citation_order.index(final_citation_num) + 1
|
|
)
|
|
|
|
# get the value that was displayed to user, should always
|
|
# be in the display_doc_order_dict. But check anyways
|
|
if context_llm_doc.document_id in self.display_order_mapping:
|
|
displayed_citation_num = self.display_order_mapping[
|
|
context_llm_doc.document_id
|
|
]
|
|
else:
|
|
displayed_citation_num = final_citation_num
|
|
logger.warning(
|
|
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
|
)
|
|
|
|
# Skip consecutive citations of the same work
|
|
if final_citation_num in self.current_citations:
|
|
start, end = citation.span()
|
|
real_start = length_to_add + start
|
|
diff = end - start
|
|
self.curr_segment = (
|
|
self.curr_segment[: length_to_add + start]
|
|
+ self.curr_segment[real_start + diff :]
|
|
)
|
|
length_to_add -= diff
|
|
continue
|
|
|
|
# Handle edge case where LLM outputs citation itself
|
|
if self.curr_segment.startswith("[["):
|
|
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
|
if match:
|
|
try:
|
|
doc_id = int(match.group(1))
|
|
context_llm_doc = self.context_docs[doc_id - 1]
|
|
yield CitationInfo(
|
|
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
|
citation_num=displayed_citation_num,
|
|
document_id=context_llm_doc.document_id,
|
|
)
|
|
except Exception as e:
|
|
logger.warning(
|
|
f"Manual LLM citation didn't properly cite documents {e}"
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Manual LLM citation wasn't able to close brackets"
|
|
)
|
|
continue
|
|
|
|
link = context_llm_doc.link
|
|
|
|
self.past_cite_count = len(self.llm_out)
|
|
self.current_citations.append(final_citation_num)
|
|
|
|
if citation_order_idx not in self.cited_inds:
|
|
self.cited_inds.add(citation_order_idx)
|
|
yield CitationInfo(
|
|
# citation number is now the one that was displayed to user
|
|
citation_num=displayed_citation_num,
|
|
document_id=context_llm_doc.document_id,
|
|
)
|
|
|
|
start, end = citation.span()
|
|
if link:
|
|
prev_length = len(self.curr_segment)
|
|
self.curr_segment = (
|
|
self.curr_segment[: start + length_to_add]
|
|
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
|
+ self.curr_segment[end + length_to_add :]
|
|
)
|
|
length_to_add += len(self.curr_segment) - prev_length
|
|
else:
|
|
prev_length = len(self.curr_segment)
|
|
self.curr_segment = (
|
|
self.curr_segment[: start + length_to_add]
|
|
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
|
+ self.curr_segment[end + length_to_add :]
|
|
)
|
|
length_to_add += len(self.curr_segment) - prev_length
|
|
|
|
last_citation_end = end + length_to_add
|
|
|
|
if last_citation_end > 0:
|
|
result += self.curr_segment[:last_citation_end]
|
|
self.curr_segment = self.curr_segment[last_citation_end:]
|
|
|
|
if not possible_citation_found:
|
|
result += self.curr_segment
|
|
self.curr_segment = ""
|
|
|
|
if result:
|
|
yield OnyxAnswerPiece(answer_piece=result)
|