mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 05:43:33 +02:00
Don't replace citations in code blocks (#911)
This commit is contained in:
@@ -29,6 +29,7 @@ from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
|
|||||||
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
|
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
|
||||||
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
|
||||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||||
|
from danswer.prompts.constants import TRIPLE_BACKTICK
|
||||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||||
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
||||||
|
|
||||||
@@ -405,11 +406,17 @@ def drop_messages_history_overflow(
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def in_code_block(llm_text: str) -> bool:
    """Report whether `llm_text` ends inside an unterminated code fence.

    The text is treated as being inside a code block when it contains an odd
    number of (non-overlapping) TRIPLE_BACKTICK occurrences, i.e. a fence has
    been opened without a matching closing fence.
    """
    # Fences come in open/close pairs, so an odd total leaves one unmatched.
    return llm_text.count(TRIPLE_BACKTICK) % 2 == 1
|
||||||
|
|
||||||
|
|
||||||
def extract_citations_from_stream(
|
def extract_citations_from_stream(
|
||||||
tokens: Iterator[str],
|
tokens: Iterator[str],
|
||||||
context_docs: list[LlmDoc],
|
context_docs: list[LlmDoc],
|
||||||
doc_id_to_rank_map: dict[str, int],
|
doc_id_to_rank_map: dict[str, int],
|
||||||
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
|
||||||
|
llm_out = ""
|
||||||
max_citation_num = len(context_docs)
|
max_citation_num = len(context_docs)
|
||||||
curr_segment = ""
|
curr_segment = ""
|
||||||
prepend_bracket = False
|
prepend_bracket = False
|
||||||
@@ -422,6 +429,7 @@ def extract_citations_from_stream(
|
|||||||
prepend_bracket = False
|
prepend_bracket = False
|
||||||
|
|
||||||
curr_segment += token
|
curr_segment += token
|
||||||
|
llm_out += token
|
||||||
|
|
||||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||||
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
possible_citation_found = re.search(possible_citation_pattern, curr_segment)
|
||||||
@@ -429,7 +437,7 @@ def extract_citations_from_stream(
|
|||||||
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
citation_pattern = r"\[(\d+)\]" # [1], [2] etc
|
||||||
citation_found = re.search(citation_pattern, curr_segment)
|
citation_found = re.search(citation_pattern, curr_segment)
|
||||||
|
|
||||||
if citation_found:
|
if citation_found and not in_code_block(llm_out):
|
||||||
numerical_value = int(citation_found.group(1))
|
numerical_value = int(citation_found.group(1))
|
||||||
if 1 <= numerical_value <= max_citation_num:
|
if 1 <= numerical_value <= max_citation_num:
|
||||||
context_llm_doc = context_docs[
|
context_llm_doc = context_docs[
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator
|
GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator
|
||||||
CODE_BLOCK_PAT = "```\n{}\n```"
|
CODE_BLOCK_PAT = "```\n{}\n```"
|
||||||
|
TRIPLE_BACKTICK = "```"
|
||||||
QUESTION_PAT = "Query:"
|
QUESTION_PAT = "Query:"
|
||||||
FINAL_QUERY_PAT = "Final Query:"
|
FINAL_QUERY_PAT = "Final Query:"
|
||||||
THOUGHT_PAT = "Thought:"
|
THOUGHT_PAT = "Thought:"
|
||||||
|
Reference in New Issue
Block a user