Don't replace citations in code blocks (#911)

This commit is contained in:
Yuhong Sun 2024-01-05 23:32:28 -08:00 committed by GitHub
parent 885e698d5d
commit 49415e4615
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 1 deletion

View File

@ -29,6 +29,7 @@ from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time
@ -405,11 +406,17 @@ def drop_messages_history_overflow(
return prompt
def in_code_block(llm_text: str) -> bool:
    """Report whether the end of ``llm_text`` lies inside an unclosed code fence.

    Each occurrence of ``TRIPLE_BACKTICK`` is treated as toggling between
    "outside" and "inside" a fenced block, so an odd number of occurrences
    means the most recently opened fence has not been closed yet.

    Args:
        llm_text: The text accumulated so far.

    Returns:
        True when the number of triple-backtick markers is odd, else False.
    """
    fence_total = llm_text.count(TRIPLE_BACKTICK)
    # str.count never returns a negative number, so "odd" is exactly "== 1".
    return fence_total % 2 == 1
def extract_citations_from_stream(
tokens: Iterator[str],
context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
llm_out = ""
max_citation_num = len(context_docs)
curr_segment = ""
prepend_bracket = False
@ -422,6 +429,7 @@ def extract_citations_from_stream(
prepend_bracket = False
curr_segment += token
llm_out += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
@ -429,7 +437,7 @@ def extract_citations_from_stream(
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
if citation_found:
if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[

View File

@ -1,5 +1,6 @@
# String constants shared by prompt-related code.
# NOTE(review): how each label below is consumed is not visible in this
# excerpt; confirm against the call sites before changing any value.

# Separator line made of fourteen hyphens.
GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator
# str.format template: puts the single "{}" argument on its own line between
# an opening and a closing triple-backtick fence.
CODE_BLOCK_PAT = "```\n{}\n```"
# The bare fence marker that appears on both ends of CODE_BLOCK_PAT.
TRIPLE_BACKTICK = "```"
# Literal text labels; presumably section markers in prompts or LLM output.
QUESTION_PAT = "Query:"
FINAL_QUERY_PAT = "Final Query:"
THOUGHT_PAT = "Thought:"