diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index 950ad2078..1333d2112 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -67,9 +67,9 @@ class CitationProcessor: if piece_that_comes_after == "\n" and in_code_block(self.llm_out): self.curr_segment = self.curr_segment.replace("```", "```plaintext") - citation_pattern = r"\[(\d+)\]" + citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" citations_found = list(re.finditer(citation_pattern, self.curr_segment)) - possible_citation_pattern = r"(\[\d*$)" # [1, [, etc + possible_citation_pattern = r"(\[+\d*$)" possible_citation_found = re.search( possible_citation_pattern, self.curr_segment ) @@ -77,13 +77,15 @@ class CitationProcessor: if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5: self.current_citations = [] - result = "" # Initialize result here + result = "" if citations_found and not in_code_block(self.llm_out): last_citation_end = 0 length_to_add = 0 while len(citations_found) > 0: citation = citations_found.pop(0) - numerical_value = int(citation.group(1)) + numerical_value = int( + next(group for group in citation.groups() if group is not None) + ) if 1 <= numerical_value <= self.max_citation_num: context_llm_doc = self.context_docs[numerical_value - 1] @@ -131,14 +133,6 @@ class CitationProcessor: link = context_llm_doc.link - # Replace the citation in the current segment - start, end = citation.span() - self.curr_segment = ( - self.curr_segment[: start + length_to_add] - + f"[{target_citation_num}]" - + self.curr_segment[end + length_to_add :] - ) - self.past_cite_count = len(self.llm_out) self.current_citations.append(target_citation_num) @@ -149,6 +143,7 @@ class CitationProcessor: document_id=context_llm_doc.document_id, ) + start, end = citation.span() if link: prev_length = len(self.curr_segment) self.curr_segment = ( diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 13e6fd73b..386f1d25e 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -385,6 +385,16 @@ def process_text( "Here is some text[[1]](https://0.com). Some other text", ["doc_0"], ), + # ['To', ' set', ' up', ' D', 'answer', ',', ' if', ' you', ' are', ' running', ' it', ' yourself', ' and', + # ' need', ' access', ' to', ' certain', ' features', ' like', ' auto', '-sync', 'ing', ' document', + # '-level', ' access', ' permissions', ',', ' you', ' should', ' reach', ' out', ' to', ' the', ' D', + # 'answer', ' team', ' to', ' receive', ' access', ' [[', '4', ']].', ''] + ( + "Unique tokens with double brackets and a single token that ends the citation and has characters after it.", + ["... to receive access", " [[", "1", "]].", ""], + "... to receive access [[1]](https://0.com).", + ["doc_0"], + ), ], ) def test_citation_extraction(