Don't replace citations in code blocks (#911)

This commit is contained in:
Yuhong Sun
2024-01-05 23:32:28 -08:00
committed by GitHub
parent 885e698d5d
commit 49415e4615
2 changed files with 10 additions and 1 deletions

View File

@@ -29,6 +29,7 @@ from danswer.prompts.chat_prompts import DEFAULT_IGNORE_STATEMENT
from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT
from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from danswer.prompts.constants import CODE_BLOCK_PAT from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.prompt_utils import get_current_llm_day_time from danswer.prompts.prompt_utils import get_current_llm_day_time
@@ -405,11 +406,17 @@ def drop_messages_history_overflow(
return prompt return prompt
def in_code_block(llm_text: str) -> bool:
count = llm_text.count(TRIPLE_BACKTICK)
return count % 2 != 0
def extract_citations_from_stream( def extract_citations_from_stream(
tokens: Iterator[str], tokens: Iterator[str],
context_docs: list[LlmDoc], context_docs: list[LlmDoc],
doc_id_to_rank_map: dict[str, int], doc_id_to_rank_map: dict[str, int],
) -> Iterator[DanswerAnswerPiece | CitationInfo]: ) -> Iterator[DanswerAnswerPiece | CitationInfo]:
llm_out = ""
max_citation_num = len(context_docs) max_citation_num = len(context_docs)
curr_segment = "" curr_segment = ""
prepend_bracket = False prepend_bracket = False
@@ -422,6 +429,7 @@ def extract_citations_from_stream(
prepend_bracket = False prepend_bracket = False
curr_segment += token curr_segment += token
llm_out += token
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment) possible_citation_found = re.search(possible_citation_pattern, curr_segment)
@@ -429,7 +437,7 @@ def extract_citations_from_stream(
citation_pattern = r"\[(\d+)\]" # [1], [2] etc citation_pattern = r"\[(\d+)\]" # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment) citation_found = re.search(citation_pattern, curr_segment)
if citation_found: if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1)) numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num: if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[ context_llm_doc = context_docs[

View File

@@ -1,5 +1,6 @@
GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator GENERAL_SEP_PAT = "--------------" # Same length as Langchain's separator
CODE_BLOCK_PAT = "```\n{}\n```" CODE_BLOCK_PAT = "```\n{}\n```"
TRIPLE_BACKTICK = "```"
QUESTION_PAT = "Query:" QUESTION_PAT = "Query:"
FINAL_QUERY_PAT = "Final Query:" FINAL_QUERY_PAT = "Final Query:"
THOUGHT_PAT = "Thought:" THOUGHT_PAT = "Thought:"