From 49415e46151d92099f763598771e74e30cc121b0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 5 Jan 2024 23:32:28 -0800 Subject: [PATCH] Don't replace citations in code blocks (#911) --- backend/danswer/chat/chat_utils.py | 10 +++++++++- backend/danswer/prompts/constants.py | 1 + 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index efc8323f4..4714a51e8 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -29,6 +29,7 @@ from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import CODE_BLOCK_PAT +from danswer.prompts.constants import TRIPLE_BACKTICK from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.prompt_utils import get_current_llm_day_time @@ -405,11 +406,17 @@ def drop_messages_history_overflow( return prompt +def in_code_block(llm_text: str) -> bool: + count = llm_text.count(TRIPLE_BACKTICK) + return count % 2 != 0 + + def extract_citations_from_stream( tokens: Iterator[str], context_docs: list[LlmDoc], doc_id_to_rank_map: dict[str, int], ) -> Iterator[DanswerAnswerPiece | CitationInfo]: + llm_out = "" max_citation_num = len(context_docs) curr_segment = "" prepend_bracket = False @@ -422,6 +429,7 @@ def extract_citations_from_stream( prepend_bracket = False curr_segment += token + llm_out += token possible_citation_pattern = r"(\[\d*$)" # [1, [, etc possible_citation_found = re.search(possible_citation_pattern, curr_segment) @@ -429,7 +437,7 @@ def extract_citations_from_stream( citation_pattern = r"\[(\d+)\]" # [1], [2] etc citation_found = re.search(citation_pattern, curr_segment) - if citation_found: + if citation_found and not in_code_block(llm_out): numerical_value = int(citation_found.group(1)) if 1 <= numerical_value <= max_citation_num: context_llm_doc = context_docs[ diff --git a/backend/danswer/prompts/constants.py b/backend/danswer/prompts/constants.py index 74e488aeb..5fb9dbf84 100644 --- a/backend/danswer/prompts/constants.py +++ b/backend/danswer/prompts/constants.py @@ -1,5 +1,6 @@ GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator CODE_BLOCK_PAT = "```\n{}\n```" +TRIPLE_BACKTICK = "```" QUESTION_PAT = "Query:" FINAL_QUERY_PAT = "Final Query:" THOUGHT_PAT = "Thought:"