Fix citations + unit tests (#1760)

pablodanswer 2024-07-10 10:05:20 -07:00 committed by GitHub
parent aa0f7abdac
commit 09a11b5e1a
11 changed files with 595 additions and 318 deletions


@@ -26,93 +26,154 @@ def extract_citations_from_stream(
    doc_id_to_rank_map: DocumentIdOrderMapping,
    stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
    """
    Key aspects:

    1. Stream Processing:
       - Processes tokens one by one, allowing for real-time handling of large texts.

    2. Citation Detection:
       - Uses regex to find citations in the format [number].
       - Example: [1], [2], etc.

    3. Citation Mapping:
       - Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
       - Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.

    4. Citation Formatting:
       - Replaces citations with properly formatted versions.
       - Adds links if available: [[1]](https://example.com)
       - Handles cases where links are not available: [[1]]()

    5. Duplicate Handling:
       - Skips consecutive citations of the same document to avoid redundancy.

    6. Output Generation:
       - Yields DanswerAnswerPiece objects for regular text.
       - Yields CitationInfo objects for each unique citation encountered.

    7. Context Awareness:
       - Uses context_docs to access document information for citations.

    This function processes a stream of text, identifies and reformats citations,
    and yields both the processed text and citation information as output.
    """
    order_mapping = doc_id_to_rank_map.order_mapping
    llm_out = ""
    max_citation_num = len(context_docs)
    citation_order = []
    curr_segment = ""
    cited_inds = set()
    hold = ""

    raw_out = ""
    current_citations: list[int] = []
    past_cite_count = 0
    for raw_token in tokens:
        raw_out += raw_token
        if stop_stream:
            next_hold = hold + raw_token
            if stop_stream in next_hold:
                break
            if next_hold == stop_stream[: len(next_hold)]:
                hold = next_hold
                continue
            token = next_hold
            hold = ""
        else:
            token = raw_token

        curr_segment += token
        llm_out += token

        citation_pattern = r"\[(\d+)\]"
        citations_found = list(re.finditer(citation_pattern, curr_segment))
        possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
        possible_citation_found = re.search(possible_citation_pattern, curr_segment)

        # `past_cite_count`: number of characters since the last citation;
        # the threshold of 5 ensures a citation hasn't just occurred
        if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
            current_citations = []

        if citations_found and not in_code_block(llm_out):
            last_citation_end = 0
            length_to_add = 0
            while len(citations_found) > 0:
                citation = citations_found.pop(0)
                numerical_value = int(citation.group(1))

                if 1 <= numerical_value <= max_citation_num:
                    context_llm_doc = context_docs[numerical_value - 1]
                    real_citation_num = order_mapping[context_llm_doc.document_id]

                    if real_citation_num not in citation_order:
                        citation_order.append(real_citation_num)

                    target_citation_num = citation_order.index(real_citation_num) + 1

                    # Skip consecutive citations of the same work
                    if target_citation_num in current_citations:
                        start, end = citation.span()
                        real_start = length_to_add + start
                        diff = end - start
                        curr_segment = (
                            curr_segment[: length_to_add + start]
                            + curr_segment[real_start + diff :]
                        )
                        length_to_add -= diff
                        continue

                    link = context_llm_doc.link

                    # Replace the citation in the current segment
                    start, end = citation.span()
                    curr_segment = (
                        curr_segment[: start + length_to_add]
                        + f"[{target_citation_num}]"
                        + curr_segment[end + length_to_add :]
                    )

                    past_cite_count = len(llm_out)
                    current_citations.append(target_citation_num)

                    if target_citation_num not in cited_inds:
                        cited_inds.add(target_citation_num)
                        yield CitationInfo(
                            citation_num=target_citation_num,
                            document_id=context_llm_doc.document_id,
                        )

                    if link:
                        prev_length = len(curr_segment)
                        curr_segment = (
                            curr_segment[: start + length_to_add]
                            + f"[[{target_citation_num}]]({link})"
                            + curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(curr_segment) - prev_length
                    else:
                        prev_length = len(curr_segment)
                        curr_segment = (
                            curr_segment[: start + length_to_add]
                            + f"[[{target_citation_num}]]()"
                            + curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(curr_segment) - prev_length

                    last_citation_end = end + length_to_add

            if last_citation_end > 0:
                yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
                curr_segment = curr_segment[last_citation_end:]

        if possible_citation_found:
            continue

        yield DanswerAnswerPiece(answer_piece=curr_segment)
        curr_segment = ""

    if curr_segment:
        yield DanswerAnswerPiece(answer_piece=curr_segment)


def build_citation_processor(
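For orientation, the two regexes above do the heavy lifting: r"\[(\d+)\]" finds complete citations in the buffered segment, while r"(\[\d*$)" detects a possibly incomplete citation at the end of the buffer so the stream can hold the segment back instead of emitting half a citation. A standalone sketch of that hold-back behavior (the buffer strings are invented for illustration; this is not the repo's code):

import re

citation_pattern = r"\[(\d+)\]"          # complete citations like [1], [12]
possible_citation_pattern = r"(\[\d*$)"  # a trailing "[" or "[1" that may still grow

for buffered in ["Growth! [", "Growth! [1", "Growth! [1]", "Growth! [1] done"]:
    complete = re.findall(citation_pattern, buffered)
    pending = re.search(possible_citation_pattern, buffered)
    # Hold the buffer back while a citation might still be forming at the end
    print(f"{buffered!r}: complete={complete}, hold_back={bool(pending)}")

# 'Growth! ['        : complete=[],    hold_back=True
# 'Growth! [1'       : complete=[],    hold_back=True
# 'Growth! [1]'      : complete=['1'], hold_back=False
# 'Growth! [1] done' : complete=['1'], hold_back=False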

backend/pytest.ini (new file)

@@ -0,0 +1,4 @@
[pytest]
pythonpath = .
markers =
    slow: marks tests as slow
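Registering the marker means @pytest.mark.slow can be applied without unknown-marker warnings, and slow tests can be deselected with pytest -m "not slow". A hypothetical usage sketch (the test name is invented for illustration, not part of this commit):

import time

import pytest


@pytest.mark.slow  # registered in backend/pytest.ini above
def test_expensive_end_to_end_flow() -> None:
    time.sleep(6)  # stand-in for genuinely slow work
    assert True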


@@ -3,6 +3,7 @@ celery-types==0.19.0
mypy-extensions==1.0.0
mypy==1.8.0
pre-commit==3.2.2
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
types-PyYAML==6.0.12.11


@@ -1,38 +0,0 @@
import unittest


class TestChatLlm(unittest.TestCase):
    def test_citation_extraction(self) -> None:
        pass  # May fix these tests some day
        """
        links: list[str | None] = [f"link_{i}" for i in range(1, 21)]

        test_1 = "Something [1]"
        res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
        self.assertEqual(res, "Something [[1]](link_1)")

        test_2 = "Something [14]"
        res = "".join(list(extract_citations_from_stream(iter(test_2), links)))
        self.assertEqual(res, "Something [[14]](link_14)")

        test_3 = "Something [14][15]"
        res = "".join(list(extract_citations_from_stream(iter(test_3), links)))
        self.assertEqual(res, "Something [[14]](link_14)[[15]](link_15)")

        test_4 = ["Something ", "[", "3", "][", "4", "]."]
        res = "".join(list(extract_citations_from_stream(iter(test_4), links)))
        self.assertEqual(res, "Something [[3]](link_3)[[4]](link_4).")

        test_5 = ["Something ", "[", "31", "][", "4", "]."]
        res = "".join(list(extract_citations_from_stream(iter(test_5), links)))
        self.assertEqual(res, "Something [31][[4]](link_4).")

        links[3] = None
        test_1 = "Something [2][4][5]"
        res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
        self.assertEqual(res, "Something [[2]](link_2)[4][[5]](link_5)")
        """


if __name__ == "__main__":
    unittest.main()


@@ -1,19 +1,13 @@
import pathlib

from danswer.file_processing.html_utils import parse_html_page_basic


def test_parse_table() -> None:
    dir_path = pathlib.Path(__file__).parent.resolve()
    with open(f"{dir_path}/test_table.html", "r") as file:
        content = file.read()

    parsed = parse_html_page_basic(content)
    expected = "\n\thello\tthere\tgeneral\n\tkenobi\ta\tb\n\tc\td\te"
    assert expected in parsed


@@ -1,36 +1,29 @@
import time

from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)


def test_rate_limit_basic() -> None:
    call_cnt = 0

    @rate_limit_builder(max_calls=2, period=5)
    def func() -> None:
        nonlocal call_cnt
        call_cnt += 1

    start = time.time()

    # Make calls that shouldn't be rate-limited
    func()
    func()
    time_to_finish_non_ratelimited = time.time() - start

    # Make a call which SHOULD be rate-limited
    func()
    time_to_finish_ratelimited = time.time() - start

    assert call_cnt == 3
    assert time_to_finish_non_ratelimited < 1
    assert time_to_finish_ratelimited > 5
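The test pins down the contract: at most max_calls invocations per rolling period seconds, with excess calls blocking until the window frees up. This is not the repo's implementation of rate_limit_builder, just a minimal sketch of a decorator that would satisfy that contract:

import time
from collections.abc import Callable
from functools import wraps
from typing import Any, TypeVar

F = TypeVar("F", bound=Callable[..., Any])


def rate_limit_builder_sketch(max_calls: int, period: float) -> Callable[[F], F]:
    """Allow at most `max_calls` invocations per rolling `period` seconds; block otherwise."""

    def decorator(func: F) -> F:
        call_times: list[float] = []

        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            now = time.monotonic()
            # Drop timestamps that have aged out of the rolling window
            while call_times and now - call_times[0] >= period:
                call_times.pop(0)
            if len(call_times) >= max_calls:
                # Sleep until the oldest call exits the window
                time.sleep(max(0.0, period - (now - call_times[0])))
                call_times.pop(0)  # that oldest call has now aged out
            call_times.append(time.monotonic())
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator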


@@ -1,5 +1,4 @@
import datetime

import pytest
from pytest_mock import MockFixture

@@ -100,7 +99,7 @@ def test_fetch_mails_from_gmail_empty(mocker: MockFixture) -> None:
        "messages": []
    }
    connector = GmailConnector()
    connector.creds = mocker.Mock()
    with pytest.raises(StopIteration):
        next(connector.load_from_state())

@@ -178,7 +177,7 @@ def test_fetch_mails_from_gmail(mocker: MockFixture) -> None:
    }
    connector = GmailConnector()
    connector.creds = mocker.Mock()
    docs = next(connector.load_from_state())
    assert len(docs) == 1
    doc: Document = docs[0]


@@ -1,7 +1,7 @@
from typing import Final

import pytest
from pytest_mock import MockFixture
from pywikibot.families.wikipedia_family import Family as WikipediaFamily  # type: ignore[import-untyped]
from pywikibot.family import Family  # type: ignore[import-untyped]

@@ -50,13 +50,11 @@ def test_family_class_dispatch_builtins(
@pytest.mark.parametrize("url, name", NON_BUILTIN_WIKIS)
def test_family_class_dispatch_on_non_builtins_generates_new_class_fast(
    url: str, name: str, mocker: MockFixture
) -> None:
    """Test that using the family class dispatch function on an unknown url generates a new family class."""
    mock_generate_family_class = mocker.patch.object(family, "generate_family_class")
    family.family_class_dispatch(url, name)

    mock_generate_family_class.assert_called_once_with(url, name)
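A side benefit of the mocker fixture used in these rewrites: patches created with mocker.patch.object are undone automatically at test teardown, which is what allows the with mock.patch.object(...) block and its extra indentation to be flattened away here and in the Gmail tests above.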


@@ -1,5 +1,6 @@
import textwrap

import pytest

from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.quotes_processing import (
@@ -11,194 +12,181 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.search.models import InferenceChunk


def test_separate_answer_quotes() -> None:
    # Test case 1: Basic quote separation
    test_answer = textwrap.dedent(
        """
        It seems many people love dogs
        Quote: A dog is a man's best friend
        Quote: Air Bud was a movie about dogs and people loved it
        """
    ).strip()
    answer, quotes = separate_answer_quotes(test_answer)
    assert answer == "It seems many people love dogs"
    assert isinstance(quotes, list)
    assert quotes[0] == "A dog is a man's best friend"
    assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

    # Test case 2: Lowercase 'quote' allowed
    test_answer = textwrap.dedent(
        """
        It seems many people love dogs
        quote: A dog is a man's best friend
        Quote: Air Bud was a movie about dogs and people loved it
        """
    ).strip()
    answer, quotes = separate_answer_quotes(test_answer)
    assert answer == "It seems many people love dogs"
    assert isinstance(quotes, list)
    assert quotes[0] == "A dog is a man's best friend"
    assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

    # Test case 3: No Answer
    test_answer = textwrap.dedent(
        """
        Quote: This one has no answer
        """
    ).strip()
    answer, quotes = separate_answer_quotes(test_answer)
    assert answer is None
    assert quotes is None

    # Test case 4: Multiline Quote
    test_answer = textwrap.dedent(
        """
        It seems many people love dogs
        quote: A well known saying is:
        A dog is a man's best friend
        Quote: Air Bud was a movie about dogs and people loved it
        """
    ).strip()
    answer, quotes = separate_answer_quotes(test_answer)
    assert answer == "It seems many people love dogs"
    assert isinstance(quotes, list)
    assert quotes[0] == "A well known saying is:\nA dog is a man's best friend"
    assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

    # Test case 5: Random patterns not picked up
    test_answer = textwrap.dedent(
        """
        It seems many people love quote: dogs
        quote: Quote: A well known saying is:
        A dog is a man's best friend
        Quote: Answer: Air Bud was a movie about dogs and quote: people loved it
        """
    ).strip()
    answer, quotes = separate_answer_quotes(test_answer)
    assert answer == "It seems many people love quote: dogs"
    assert isinstance(quotes, list)
    assert quotes[0] == "Quote: A well known saying is:\nA dog is a man's best friend"
    assert (
        quotes[1] == "Answer: Air Bud was a movie about dogs and quote: people loved it"
    )


@pytest.mark.skip(
    reason="Using fuzzy match is too slow anyway, doesn't matter if it's broken"
)
def test_fuzzy_match_quotes_to_docs() -> None:
    chunk_0_text = textwrap.dedent(
        """
        Here's a doc with some LINK embedded in the text
        THIS SECTION IS A LINK
        Some more text
        """
    ).strip()
    chunk_1_text = textwrap.dedent(
        """
        Some completely different text here
        ANOTHER LINK embedded in this text
        ending in a DIFFERENT-LINK
        """
    ).strip()
    test_chunk_0 = InferenceChunk(
        document_id="test doc 0",
        source_type=DocumentSource.FILE,
        chunk_id=0,
        content=chunk_0_text,
        source_links={
            0: "doc 0 base",
            23: "first line link",
            49: "second line link",
        },
        blurb="anything",
        semantic_identifier="anything",
        section_continuation=False,
        recency_bias=1,
        boost=0,
        hidden=False,
        score=1,
        metadata={},
        match_highlights=[],
        updated_at=None,
    )
    test_chunk_1 = InferenceChunk(
        document_id="test doc 1",
        source_type=DocumentSource.FILE,
        chunk_id=0,
        content=chunk_1_text,
        source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
        blurb="whatever",
        semantic_identifier="whatever",
        section_continuation=False,
        recency_bias=1,
        boost=0,
        hidden=False,
        score=1,
        metadata={},
        match_highlights=[],
        updated_at=None,
    )

    test_quotes = [
        "a doc with some",  # Basic case
        "a doc with some LINK",  # Should take the start of quote, even if a link is in it
        "a doc with some \nLINK",  # Requires a newline deletion fuzzy match
        "a doc with some link",  # Capitalization insensitive
        "embedded in this text",  # Fuzzy match to first doc
        "SECTION IS A LINK",  # Match exact link
        "some more text",  # Match the end, after every link offset
        "different taxt",  # Substitution
        "embedded in this texts",  # Cannot fuzzy match to first doc, fuzzy match to second doc
        "DIFFERENT-LINK",  # Exact link match at the end
        "Some complitali",  # Too many edits, shouldn't match anything
    ]
    results = match_quotes_to_docs(
        test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True
    )
    assert results == {
        "a doc with some": {"document": "test doc 0", "link": "doc 0 base"},
        "a doc with some LINK": {
            "document": "test doc 0",
            "link": "doc 0 base",
        },
        "a doc with some \nLINK": {
            "document": "test doc 0",
            "link": "doc 0 base",
        },
        "a doc with some link": {
            "document": "test doc 0",
            "link": "doc 0 base",
        },
        "embedded in this text": {
            "document": "test doc 0",
            "link": "first line link",
        },
        "SECTION IS A LINK": {
            "document": "test doc 0",
            "link": "second line link",
        },
        "some more text": {
            "document": "test doc 0",
            "link": "second line link",
        },
        "different taxt": {"document": "test doc 1", "link": "doc 1 base"},
        "embedded in this texts": {
            "document": "test doc 1",
            "link": "2nd line link",
        },
        "DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"},
    }
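Reading these fixtures: the integer keys in source_links are character offsets into the chunk content marking where each link's region begins (offset 23 is where LINK starts on the first line, offset 49 where the second line begins), so a matched quote is attributed to the link region it falls in, as the expected results above show.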


@@ -0,0 +1,277 @@
from datetime import datetime

import pytest

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.citation_processing import (
    extract_citations_from_stream,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping

"""
This module contains tests for the citation extraction functionality in Danswer.

The tests focus on the `extract_citations_from_stream` function, which processes
a stream of tokens and extracts citations, replacing them with properly formatted
versions including links where available.

Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their ranks.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_extraction: A parametrized test function covering various citation scenarios.

To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction.
2. Each tuple should contain:
   - A descriptive test name (string)
   - Input tokens (list of strings)
   - Expected output text (string)
   - Expected citations (list of document IDs)
"""
# Each pair of consecutive ids shares a document_id ("doc_0", "doc_0", "doc_1", ...);
# only the even-numbered documents (doc_0, doc_2, doc_4) get a link.
mock_docs = [
    LlmDoc(
        document_id=f"doc_{int(id/2)}",
        content="Document is a doc",
        blurb=f"Document #{id}",
        semantic_identifier=f"Doc {id}",
        source_type=DocumentSource.WEB,
        metadata={},
        updated_at=datetime.now(),
        link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
        source_links={0: "https://mintlify.com/docs/settings/broken-links"},
    )
    for id in range(10)
]

mock_doc_mapping = {
    "doc_0": 1,
    "doc_1": 2,
    "doc_2": 3,
    "doc_3": 4,
    "doc_4": 5,
    "doc_5": 6,
}


@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
    return mock_docs, mock_doc_mapping


def process_text(
    tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
    mock_docs, mock_doc_id_to_rank_map = mock_data
    mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
    result = list(
        extract_citations_from_stream(
            tokens=iter(tokens),
            context_docs=mock_docs,
            doc_id_to_rank_map=mapping,
            stop_stream=None,
        )
    )
    final_answer_text = ""
    citations = []
    for piece in result:
        if isinstance(piece, DanswerAnswerPiece):
            final_answer_text += piece.answer_piece or ""
        elif isinstance(piece, CitationInfo):
            citations.append(piece)
    return final_answer_text, citations
@pytest.mark.parametrize(
    "test_name, input_tokens, expected_text, expected_citations",
    [
        (
            "Single citation",
            ["Gro", "wth! [", "1", "]", "."],
            "Growth! [[1]](https://0.com).",
            ["doc_0"],
        ),
        (
            "Repeated citations",
            ["Test! ", "[", "1", "]", ". And so", "me more ", "[", "2", "]", "."],
            "Test! [[1]](https://0.com). And some more [[1]](https://0.com).",
            ["doc_0"],
        ),
        (
            "Citations at sentence boundaries",
            [
                "Citation at the ",
                "end of a sen",
                "tence.",
                "[",
                "2",
                "]",
                " Another sen",
                "tence.",
                "[",
                "4",
                "]",
            ],
            "Citation at the end of a sentence.[[1]](https://0.com) Another sentence.[[2]]()",
            ["doc_0", "doc_1"],
        ),
        (
            "Citations at beginning, middle, and end",
            [
                "[",
                "1",
                "]",
                " Citation at ",
                "the beginning. ",
                "[",
                "3",
                "]",
                " In the mid",
                "dle. At the end ",
                "[",
                "5",
                "]",
                ".",
            ],
            "[[1]](https://0.com) Citation at the beginning. [[2]]() In the middle. At the end [[3]](https://2.com).",
            ["doc_0", "doc_1", "doc_2"],
        ),
        (
            "Mixed valid and invalid citations",
            [
                "Mixed valid and in",
                "valid citations ",
                "[",
                "1",
                "]",
                "[",
                "99",
                "]",
                "[",
                "3",
                "]",
                "[",
                "100",
                "]",
                "[",
                "5",
                "]",
                ".",
            ],
            "Mixed valid and invalid citations [[1]](https://0.com)[99][[2]]()[100][[3]](https://2.com).",
            ["doc_0", "doc_1", "doc_2"],
        ),
        (
            "Hardest!",
            [
                "Multiple cit",
                "ations in one ",
                "sentence [",
                "1",
                "]",
                "[",
                "4",
                "]",
                "[",
                "5",
                "]",
                ". ",
            ],
            "Multiple citations in one sentence [[1]](https://0.com)[[2]]()[[3]](https://2.com).",
            ["doc_0", "doc_1", "doc_2"],
        ),
        (
            "Repeated citations with text",
            ["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."],
            "[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .",
            ["doc_0"],
        ),
        (
            "Consecutive identical citations!",
            [
                "Citations [",
                "1",
                "]",
                "[",
                "1]",
                "",
                "[2",
                "",
                "]",
                ". ",
            ],
            "Citations [[1]](https://0.com).",
            ["doc_0"],
        ),
        (
            "Consecutive identical citations!",
            [
                "test [1]tt[1]t",
                "",
            ],
            "test [[1]](https://0.com)ttt",
            ["doc_0"],
        ),
        (
            "Consecutive identical citations!",
            [
                "test [1]t[1]t[1]",
                "",
            ],
            "test [[1]](https://0.com)tt",
            ["doc_0"],
        ),
        (
            "Repeated citations with text",
            ["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."],
            "[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .",
            ["doc_0"],
        ),
        (
            "Repeated citations with text",
            ["[1][", "1", "]t", "[2]"],
            "[[1]](https://0.com)t",
            ["doc_0"],
        ),
        (
            "Repeated citations with text",
            ["[1][", "1", "]t]", "[2]"],
            "[[1]](https://0.com)t]",
            ["doc_0"],
        ),
        (
            "Repeated citations with text",
            ["[1][", "3", "]t]", "[2]"],
            "[[1]](https://0.com)[[2]]()t]",
            ["doc_0", "doc_1"],
        ),
        (
            "Repeated citations with text",
            ["[1", "][", "3", "]t]", "[2]"],
            "[[1]](https://0.com)[[2]]()t]",
            ["doc_0", "doc_1"],
        ),
    ],
)
def test_citation_extraction(
    mock_data: tuple[list[LlmDoc], dict[str, int]],
    test_name: str,
    input_tokens: list[str],
    expected_text: str,
    expected_citations: list[str],
) -> None:
    final_answer_text, citations = process_text(input_tokens, mock_data)
    assert (
        final_answer_text.strip() == expected_text.strip()
    ), f"Test '{test_name}' failed: Final answer text does not match expected output."
    assert [
        citation.document_id for citation in citations
    ] == expected_citations, (
        f"Test '{test_name}' failed: Citations do not match expected output."
    )


@@ -1 +1 @@
f1f2 1 1718910083.03085 wikipedia:en