diff --git a/backend/danswer/llm/answering/stream_processing/citation_processing.py b/backend/danswer/llm/answering/stream_processing/citation_processing.py index d1d1eb078..de80b6f67 100644 --- a/backend/danswer/llm/answering/stream_processing/citation_processing.py +++ b/backend/danswer/llm/answering/stream_processing/citation_processing.py @@ -125,6 +125,30 @@ def extract_citations_from_stream( length_to_add -= diff continue + # Handle edge case where LLM outputs citation itself + # by allowing it to generate citations on its own. + if curr_segment.startswith("[["): + match = re.match(r"\[\[(\d+)\]\]", curr_segment) + if match: + try: + doc_id = int(match.group(1)) + context_llm_doc = context_docs[doc_id - 1] + yield CitationInfo( + citation_num=target_citation_num, + document_id=context_llm_doc.document_id, + ) + except Exception as e: + logger.warning( + f"Manual LLM citation didn't properly cite documents {e}" + ) + else: + # Will continue attempt on next loops + logger.warning( + "Manual LLM citation wasn't able to close brackets" + ) + + continue + link = context_llm_doc.link # Replace the citation in the current segment @@ -162,6 +186,7 @@ def extract_citations_from_stream( + curr_segment[end + length_to_add :] ) length_to_add += len(curr_segment) - prev_length + last_citation_end = end + length_to_add if last_citation_end > 0: diff --git a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py index 87d8e0d32..473ccf245 100644 --- a/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py +++ b/backend/tests/unit/danswer/llm/answering/stream_processing/test_citation_processing.py @@ -257,6 +257,35 @@ def process_text( "[[1]](https://0.com)[[2]]()t]", ["doc_0", "doc_1"], ), + ( + "Citations with extraneous citations", + [ + "[[1]](https://0.com) Citation", + " at ", + "the beginning. ", + "[", + "3", + "]", + " In the mid", + "dle. At the end ", + "[", + "5", + "]", + ".", + ], + "[[1]](https://0.com) Citation at the beginning. [[2]]() In the middle. At the end [[3]](https://2.com).", + ["doc_0", "doc_1", "doc_2"], + ), + ( + "Citations with extraneous citations, split up", + [ + "[[1]](", + "https://0.com) Citation at ", + "the beginning. ", + ], + "[[1]](https://0.com) Citation at the beginning. ", + ["doc_0"], + ), ], ) def test_citation_extraction(