Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-05-03 00:10:24 +02:00)

Fix citations + unit tests (#1760)
This commit is contained in:
parent aa0f7abdac
commit 09a11b5e1a
@@ -26,93 +26,154 @@ def extract_citations_from_stream(
doc_id_to_rank_map: DocumentIdOrderMapping,
stop_stream: str | None = STOP_STREAM_PAT,
) -> Iterator[DanswerAnswerPiece | CitationInfo]:
"""
Key aspects:

1. Stream Processing:
- Processes tokens one by one, allowing for real-time handling of large texts.

2. Citation Detection:
- Uses regex to find citations in the format [number].
- Example: [1], [2], etc.

3. Citation Mapping:
- Maps detected citation numbers to actual document ranks using doc_id_to_rank_map.
- Example: [1] might become [3] if doc_id_to_rank_map maps it to 3.

4. Citation Formatting:
- Replaces citations with properly formatted versions.
- Adds links if available: [[1]](https://example.com)
- Handles cases where links are not available: [[1]]()

5. Duplicate Handling:
- Skips consecutive citations of the same document to avoid redundancy.

6. Output Generation:
- Yields DanswerAnswerPiece objects for regular text.
- Yields CitationInfo objects for each unique citation encountered.

7. Context Awareness:
- Uses context_docs to access document information for citations.

This function effectively processes a stream of text, identifies and reformats citations,
and provides both the processed text and citation information as output.
"""
order_mapping = doc_id_to_rank_map.order_mapping
llm_out = ""
max_citation_num = len(context_docs)
citation_order = []
curr_segment = ""
prepend_bracket = False
cited_inds = set()
hold = ""

raw_out = ""
current_citations: list[int] = []
past_cite_count = 0
for raw_token in tokens:
raw_out += raw_token
if stop_stream:
next_hold = hold + raw_token

if stop_stream in next_hold:
break

if next_hold == stop_stream[: len(next_hold)]:
hold = next_hold
continue

token = next_hold
hold = ""
else:
token = raw_token

# Special case of [1][ where ][ is a single token
# This is where the model attempts to do consecutive citations like [1][2]
if prepend_bracket:
curr_segment += "[" + curr_segment
prepend_bracket = False

curr_segment += token
llm_out += token

citation_pattern = r"\[(\d+)\]"

citations_found = list(re.finditer(citation_pattern, curr_segment))
possible_citation_pattern = r"(\[\d*$)"  # [1, [, etc
possible_citation_found = re.search(possible_citation_pattern, curr_segment)

citation_pattern = r"\[(\d+)\]"  # [1], [2] etc
citation_found = re.search(citation_pattern, curr_segment)
# `past_cite_count`: number of characters since past citation
# 5 to ensure a citation hasn't occured
if len(citations_found) == 0 and len(llm_out) - past_cite_count > 5:
current_citations = []

if citation_found and not in_code_block(llm_out):
numerical_value = int(citation_found.group(1))
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[
numerical_value - 1
]  # remove 1 index offset
if citations_found and not in_code_block(llm_out):
last_citation_end = 0
length_to_add = 0
while len(citations_found) > 0:
citation = citations_found.pop(0)
numerical_value = int(citation.group(1))

link = context_llm_doc.link
target_citation_num = doc_id_to_rank_map.order_mapping[
context_llm_doc.document_id
]
if 1 <= numerical_value <= max_citation_num:
context_llm_doc = context_docs[numerical_value - 1]
real_citation_num = order_mapping[context_llm_doc.document_id]

# Use the citation number for the document's rank in
# the search (or selected docs) results
curr_segment = re.sub(
rf"\[{numerical_value}\]", f"[{target_citation_num}]", curr_segment
)
if real_citation_num not in citation_order:
citation_order.append(real_citation_num)

if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
target_citation_num = citation_order.index(real_citation_num) + 1

# Skip consecutive citations of the same work
if target_citation_num in current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
curr_segment = (
curr_segment[: length_to_add + start]
+ curr_segment[real_start + diff :]
)
length_to_add -= diff
continue

link = context_llm_doc.link

# Replace the citation in the current segment
start, end = citation.span()
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[{target_citation_num}]"
+ curr_segment[end + length_to_add :]
)

if link:
curr_segment = re.sub(r"\[", "[[", curr_segment, count=1)
curr_segment = re.sub("]", f"]]({link})", curr_segment, count=1)
past_cite_count = len(llm_out)
current_citations.append(target_citation_num)

# In case there's another open bracket like [1][, don't want to match this
possible_citation_found = None
if target_citation_num not in cited_inds:
cited_inds.add(target_citation_num)
yield CitationInfo(
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)

# if we see "[", but haven't seen the right side, hold back - this may be a
# citation that needs to be replaced with a link
if link:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length

else:
prev_length = len(curr_segment)
curr_segment = (
curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ curr_segment[end + length_to_add :]
)
length_to_add += len(curr_segment) - prev_length
last_citation_end = end + length_to_add

if last_citation_end > 0:
yield DanswerAnswerPiece(answer_piece=curr_segment[:last_citation_end])
curr_segment = curr_segment[last_citation_end:]
if possible_citation_found:
continue

# Special case with back to back citations [1][2]
if curr_segment and curr_segment[-1] == "[":
curr_segment = curr_segment[:-1]
prepend_bracket = True

yield DanswerAnswerPiece(answer_piece=curr_segment)
curr_segment = ""

if curr_segment:
if prepend_bracket:
yield DanswerAnswerPiece(answer_piece="[" + curr_segment)
else:
yield DanswerAnswerPiece(answer_piece=curr_segment)
yield DanswerAnswerPiece(answer_piece=curr_segment)


def build_citation_processor(
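A minimal usage sketch, not part of the diff: it assumes the same imports the new unit tests below use, and the helper name render_answer is hypothetical. It shows how the documented behavior plays out end to end, with citation numbers remapped to document ranks and formatted as markdown links.

from danswer.chat.models import CitationInfo, DanswerAnswerPiece, LlmDoc
from danswer.llm.answering.stream_processing.citation_processing import (
    extract_citations_from_stream,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping


def render_answer(tokens: list[str], docs: list[LlmDoc], ranks: dict[str, int]) -> str:
    """Collect the streamed pieces into the final, citation-formatted answer text."""
    text = ""
    for piece in extract_citations_from_stream(
        tokens=iter(tokens),
        context_docs=docs,
        doc_id_to_rank_map=DocumentIdOrderMapping(order_mapping=ranks),
        stop_stream=None,
    ):
        if isinstance(piece, DanswerAnswerPiece):
            text += piece.answer_piece or ""
        elif isinstance(piece, CitationInfo):
            # one CitationInfo is emitted per unique cited document
            print(f"cited {piece.document_id} as [{piece.citation_num}]")
    return text

Per the tests added in this commit, a stream like ["Gro", "wth! [", "1", "]", "."] over a doc ranked 1 with a link is expected to come back as "Growth! [[1]](https://0.com).".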
backend/pytest.ini (new file, +4 lines)
@@ -0,0 +1,4 @@
[pytest]
pythonpath = .
markers =
    slow: marks tests as slow
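Registering the marker means individual tests can opt in with the standard pytest decorator and be deselected at run time. A minimal sketch (the test name and body are hypothetical, not from this commit):

import time

import pytest


@pytest.mark.slow  # deselect these with: pytest -m "not slow"
def test_expensive_roundtrip() -> None:
    time.sleep(2)  # stand-in for an expensive operation
    assert True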
@@ -3,6 +3,7 @@ celery-types==0.19.0
mypy-extensions==1.0.0
mypy==1.8.0
pre-commit==3.2.2
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
types-PyYAML==6.0.12.11
@@ -1,38 +0,0 @@
import unittest


class TestChatLlm(unittest.TestCase):
def test_citation_extraction(self) -> None:
pass  # May fix these tests some day
"""
links: list[str | None] = [f"link_{i}" for i in range(1, 21)]

test_1 = "Something [1]"
res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
self.assertEqual(res, "Something [[1]](link_1)")

test_2 = "Something [14]"
res = "".join(list(extract_citations_from_stream(iter(test_2), links)))
self.assertEqual(res, "Something [[14]](link_14)")

test_3 = "Something [14][15]"
res = "".join(list(extract_citations_from_stream(iter(test_3), links)))
self.assertEqual(res, "Something [[14]](link_14)[[15]](link_15)")

test_4 = ["Something ", "[", "3", "][", "4", "]."]
res = "".join(list(extract_citations_from_stream(iter(test_4), links)))
self.assertEqual(res, "Something [[3]](link_3)[[4]](link_4).")

test_5 = ["Something ", "[", "31", "][", "4", "]."]
res = "".join(list(extract_citations_from_stream(iter(test_5), links)))
self.assertEqual(res, "Something [31][[4]](link_4).")

links[3] = None
test_1 = "Something [2][4][5]"
res = "".join(list(extract_citations_from_stream(iter(test_1), links)))
self.assertEqual(res, "Something [[2]](link_2)[4][[5]](link_5)")
"""


if __name__ == "__main__":
unittest.main()
@@ -1,19 +1,13 @@
import pathlib
import unittest

from danswer.file_processing.html_utils import parse_html_page_basic


class TestQAPostprocessing(unittest.TestCase):
def test_parse_table(self) -> None:
dir_path = pathlib.Path(__file__).parent.resolve()
with open(f"{dir_path}/test_table.html", "r") as file:
content = file.read()
def test_parse_table() -> None:
dir_path = pathlib.Path(__file__).parent.resolve()
with open(f"{dir_path}/test_table.html", "r") as file:
content = file.read()

parsed = parse_html_page_basic(content)
expected = "\n\thello\tthere\tgeneral\n\tkenobi\ta\tb\n\tc\td\te"
self.assertIn(expected, parsed)


if __name__ == "__main__":
unittest.main()
parsed = parse_html_page_basic(content)
expected = "\n\thello\tthere\tgeneral\n\tkenobi\ta\tb\n\tc\td\te"
assert expected in parsed
@@ -1,36 +1,29 @@
import time
import unittest

from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)


class TestRateLimit(unittest.TestCase):
def test_rate_limit_basic() -> None:
call_cnt = 0

def test_rate_limit_basic(self) -> None:
self.call_cnt = 0
@rate_limit_builder(max_calls=2, period=5)
def func() -> None:
nonlocal call_cnt
call_cnt += 1

@rate_limit_builder(max_calls=2, period=5)
def func() -> None:
self.call_cnt += 1
start = time.time()

start = time.time()
# Make calls that shouldn't be rate-limited
func()
func()
time_to_finish_non_ratelimited = time.time() - start

# make calls that shouldn't be rate-limited
func()
func()
time_to_finish_non_ratelimited = time.time() - start
# Make a call which SHOULD be rate-limited
func()
time_to_finish_ratelimited = time.time() - start

# make a call which SHOULD be rate-limited
func()
time_to_finish_ratelimited = time.time() - start

self.assertEqual(self.call_cnt, 3)
self.assertLess(time_to_finish_non_ratelimited, 1)
self.assertGreater(time_to_finish_ratelimited, 5)


if __name__ == "__main__":
unittest.main()
assert call_cnt == 3
assert time_to_finish_non_ratelimited < 1
assert time_to_finish_ratelimited > 5
@@ -1,5 +1,4 @@
import datetime
from unittest.mock import MagicMock

import pytest
from pytest_mock import MockFixture
@@ -100,7 +99,7 @@ def test_fetch_mails_from_gmail_empty(mocker: MockFixture) -> None:
"messages": []
}
connector = GmailConnector()
connector.creds = MagicMock()
connector.creds = mocker.Mock()
with pytest.raises(StopIteration):
next(connector.load_from_state())

@@ -178,7 +177,7 @@ def test_fetch_mails_from_gmail(mocker: MockFixture) -> None:
}

connector = GmailConnector()
connector.creds = MagicMock()
connector.creds = mocker.Mock()
docs = next(connector.load_from_state())
assert len(docs) == 1
doc: Document = docs[0]
@@ -1,7 +1,7 @@
from typing import Final
from unittest import mock

import pytest
from pytest_mock import MockFixture
from pywikibot.families.wikipedia_family import Family as WikipediaFamily  # type: ignore[import-untyped]
from pywikibot.family import Family  # type: ignore[import-untyped]

@@ -50,13 +50,11 @@ def test_family_class_dispatch_builtins(

@pytest.mark.parametrize("url, name", NON_BUILTIN_WIKIS)
def test_family_class_dispatch_on_non_builtins_generates_new_class_fast(
url: str, name: str
url: str, name: str, mocker: MockFixture
) -> None:
"""Test that using the family class dispatch function on an unknown url generates a new family class."""
with mock.patch.object(
family, "generate_family_class"
) as mock_generate_family_class:
family.family_class_dispatch(url, name)
mock_generate_family_class = mocker.patch.object(family, "generate_family_class")
family.family_class_dispatch(url, name)
mock_generate_family_class.assert_called_once_with(url, name)

@@ -1,5 +1,6 @@
import textwrap
import unittest

import pytest

from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.quotes_processing import (
@@ -11,194 +12,181 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.search.models import InferenceChunk


class TestQAPostprocessing(unittest.TestCase):
def test_separate_answer_quotes(self) -> None:
test_answer = textwrap.dedent(
"""
It seems many people love dogs
Quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(quotes[0], "A dog is a man's best friend")  # type: ignore
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"  # type: ignore
)
def test_separate_answer_quotes() -> None:
# Test case 1: Basic quote separation
test_answer = textwrap.dedent(
"""
It seems many people love dogs
Quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
assert answer == "It seems many people love dogs"
assert isinstance(quotes, list)
assert quotes[0] == "A dog is a man's best friend"
assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

# Lowercase should be allowed
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(quotes[0], "A dog is a man's best friend")  # type: ignore
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"  # type: ignore
)
# Test case 2: Lowercase 'quote' allowed
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
assert answer == "It seems many people love dogs"
assert isinstance(quotes, list)
assert quotes[0] == "A dog is a man's best friend"
assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

# No Answer
test_answer = textwrap.dedent(
"""
Quote: This one has no answer
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertIsNone(answer)
self.assertIsNone(quotes)
# Test case 3: No Answer
test_answer = textwrap.dedent(
"""
Quote: This one has no answer
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
assert answer is None
assert quotes is None

# Multiline Quote
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A well known saying is:
A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love dogs")
self.assertEqual(
quotes[0], "A well known saying is:\nA dog is a man's best friend"  # type: ignore
)
self.assertEqual(
quotes[1], "Air Bud was a movie about dogs and people loved it"  # type: ignore
)
# Test case 4: Multiline Quote
test_answer = textwrap.dedent(
"""
It seems many people love dogs
quote: A well known saying is:
A dog is a man's best friend
Quote: Air Bud was a movie about dogs and people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
assert answer == "It seems many people love dogs"
assert isinstance(quotes, list)
assert quotes[0] == "A well known saying is:\nA dog is a man's best friend"
assert quotes[1] == "Air Bud was a movie about dogs and people loved it"

# Random patterns not picked up
test_answer = textwrap.dedent(
"""
It seems many people love quote: dogs
quote: Quote: A well known saying is:
A dog is a man's best friend
Quote: Answer: Air Bud was a movie about dogs and quote: people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
self.assertEqual(answer, "It seems many people love quote: dogs")
self.assertEqual(
quotes[0], "Quote: A well known saying is:\nA dog is a man's best friend"  # type: ignore
)
self.assertEqual(
quotes[1],  # type: ignore
"Answer: Air Bud was a movie about dogs and quote: people loved it",
)

@unittest.skip(
"Using fuzzy match is too slow anyway, doesn't matter if it's broken"
# Test case 5: Random patterns not picked up
test_answer = textwrap.dedent(
"""
It seems many people love quote: dogs
quote: Quote: A well known saying is:
A dog is a man's best friend
Quote: Answer: Air Bud was a movie about dogs and quote: people loved it
"""
).strip()
answer, quotes = separate_answer_quotes(test_answer)
assert answer == "It seems many people love quote: dogs"
assert isinstance(quotes, list)
assert quotes[0] == "Quote: A well known saying is:\nA dog is a man's best friend"
assert (
quotes[1] == "Answer: Air Bud was a movie about dogs and quote: people loved it"
)
def test_fuzzy_match_quotes_to_docs(self) -> None:
chunk_0_text = textwrap.dedent(
"""
Here's a doc with some LINK embedded in the text
THIS SECTION IS A LINK
Some more text
"""
).strip()
chunk_1_text = textwrap.dedent(
"""
Some completely different text here
ANOTHER LINK embedded in this text
ending in a DIFFERENT-LINK
"""
).strip()
test_chunk_0 = InferenceChunk(
document_id="test doc 0",
source_type=DocumentSource.FILE,
chunk_id=0,
content=chunk_0_text,
source_links={
0: "doc 0 base",
23: "first line link",
49: "second line link",
},
blurb="anything",
semantic_identifier="anything",
section_continuation=False,
recency_bias=1,
boost=0,
hidden=False,
score=1,
metadata={},
match_highlights=[],
updated_at=None,
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
source_type=DocumentSource.FILE,
chunk_id=0,
content=chunk_1_text,
source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
blurb="whatever",
semantic_identifier="whatever",
section_continuation=False,
recency_bias=1,
boost=0,
hidden=False,
score=1,
metadata={},
match_highlights=[],
updated_at=None,
)

test_quotes = [
"a doc with some",  # Basic case
"a doc with some LINK",  # Should take the start of quote, even if a link is in it
"a doc with some \nLINK",  # Requires a newline deletion fuzzy match
"a doc with some link",  # Capitalization insensitive
"embedded in this text",  # Fuzzy match to first doc
"SECTION IS A LINK",  # Match exact link
"some more text",  # Match the end, after every link offset
"different taxt",  # Substitution
"embedded in this texts",  # Cannot fuzzy match to first doc, fuzzy match to second doc
"DIFFERENT-LINK",  # Exact link match at the end
"Some complitali",  # Too many edits, shouldn't match anything
]
results = match_quotes_to_docs(
test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True
)
self.assertEqual(
results,
{
"a doc with some": {"document": "test doc 0", "link": "doc 0 base"},
"a doc with some LINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some \nLINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some link": {
"document": "test doc 0",
"link": "doc 0 base",
},
"embedded in this text": {
"document": "test doc 0",
"link": "first line link",
},
"SECTION IS A LINK": {
"document": "test doc 0",
"link": "second line link",
},
"some more text": {
"document": "test doc 0",
"link": "second line link",
},
"different taxt": {"document": "test doc 1", "link": "doc 1 base"},
"embedded in this texts": {
"document": "test doc 1",
"link": "2nd line link",
},
"DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"},
},
)


if __name__ == "__main__":
unittest.main()
@pytest.mark.skip(
reason="Using fuzzy match is too slow anyway, doesn't matter if it's broken"
)
def test_fuzzy_match_quotes_to_docs() -> None:
chunk_0_text = textwrap.dedent(
"""
Here's a doc with some LINK embedded in the text
THIS SECTION IS A LINK
Some more text
"""
).strip()
chunk_1_text = textwrap.dedent(
"""
Some completely different text here
ANOTHER LINK embedded in this text
ending in a DIFFERENT-LINK
"""
).strip()
test_chunk_0 = InferenceChunk(
document_id="test doc 0",
source_type=DocumentSource.FILE,
chunk_id=0,
content=chunk_0_text,
source_links={
0: "doc 0 base",
23: "first line link",
49: "second line link",
},
blurb="anything",
semantic_identifier="anything",
section_continuation=False,
recency_bias=1,
boost=0,
hidden=False,
score=1,
metadata={},
match_highlights=[],
updated_at=None,
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
source_type=DocumentSource.FILE,
chunk_id=0,
content=chunk_1_text,
source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
blurb="whatever",
semantic_identifier="whatever",
section_continuation=False,
recency_bias=1,
boost=0,
hidden=False,
score=1,
metadata={},
match_highlights=[],
updated_at=None,
)

test_quotes = [
"a doc with some",  # Basic case
"a doc with some LINK",  # Should take the start of quote, even if a link is in it
"a doc with some \nLINK",  # Requires a newline deletion fuzzy match
"a doc with some link",  # Capitalization insensitive
"embedded in this text",  # Fuzzy match to first doc
"SECTION IS A LINK",  # Match exact link
"some more text",  # Match the end, after every link offset
"different taxt",  # Substitution
"embedded in this texts",  # Cannot fuzzy match to first doc, fuzzy match to second doc
"DIFFERENT-LINK",  # Exact link match at the end
"Some complitali",  # Too many edits, shouldn't match anything
]
results = match_quotes_to_docs(
test_quotes, [test_chunk_0, test_chunk_1], fuzzy_search=True
)
assert results == {
"a doc with some": {"document": "test doc 0", "link": "doc 0 base"},
"a doc with some LINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some \nLINK": {
"document": "test doc 0",
"link": "doc 0 base",
},
"a doc with some link": {
"document": "test doc 0",
"link": "doc 0 base",
},
"embedded in this text": {
"document": "test doc 0",
"link": "first line link",
},
"SECTION IS A LINK": {
"document": "test doc 0",
"link": "second line link",
},
"some more text": {
"document": "test doc 0",
"link": "second line link",
},
"different taxt": {"document": "test doc 1", "link": "doc 1 base"},
"embedded in this texts": {
"document": "test doc 1",
"link": "2nd line link",
},
"DIFFERENT-LINK": {"document": "test doc 1", "link": "last link"},
}
@@ -0,0 +1,277 @@
from datetime import datetime

import pytest

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.stream_processing.citation_processing import (
extract_citations_from_stream,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping


"""
This module contains tests for the citation extraction functionality in Danswer.

The tests focus on the `extract_citations_from_stream` function, which processes
a stream of tokens and extracts citations, replacing them with properly formatted
versions including links where available.

Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their ranks.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_extraction: A parametrized test function covering various citation scenarios.

To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction.
2. Each tuple should contain:
- A descriptive test name (string)
- Input tokens (list of strings)
- Expected output text (string)
- Expected citations (list of document IDs)
"""


mock_docs = [
LlmDoc(
document_id=f"doc_{int(id/2)}",
content="Document is a doc",
blurb=f"Document #{id}",
semantic_identifier=f"Doc {id}",
source_type=DocumentSource.WEB,
metadata={},
updated_at=datetime.now(),
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
)
for id in range(10)
]

mock_doc_mapping = {
"doc_0": 1,
"doc_1": 2,
"doc_2": 3,
"doc_3": 4,
"doc_4": 5,
"doc_5": 6,
}


@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
return mock_docs, mock_doc_mapping


def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
result = list(
extract_citations_from_stream(
tokens=iter(tokens),
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
stop_stream=None,
)
)
final_answer_text = ""
citations = []
for piece in result:
if isinstance(piece, DanswerAnswerPiece):
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations


@pytest.mark.parametrize(
"test_name, input_tokens, expected_text, expected_citations",
[
(
"Single citation",
["Gro", "wth! [", "1", "]", "."],
"Growth! [[1]](https://0.com).",
["doc_0"],
),
(
"Repeated citations",
["Test! ", "[", "1", "]", ". And so", "me more ", "[", "2", "]", "."],
"Test! [[1]](https://0.com). And some more [[1]](https://0.com).",
["doc_0"],
),
(
"Citations at sentence boundaries",
[
"Citation at the ",
"end of a sen",
"tence.",
"[",
"2",
"]",
" Another sen",
"tence.",
"[",
"4",
"]",
],
"Citation at the end of a sentence.[[1]](https://0.com) Another sentence.[[2]]()",
["doc_0", "doc_1"],
),
(
"Citations at beginning, middle, and end",
[
"[",
"1",
"]",
" Citation at ",
"the beginning. ",
"[",
"3",
"]",
" In the mid",
"dle. At the end ",
"[",
"5",
"]",
".",
],
"[[1]](https://0.com) Citation at the beginning. [[2]]() In the middle. At the end [[3]](https://2.com).",
["doc_0", "doc_1", "doc_2"],
),
(
"Mixed valid and invalid citations",
[
"Mixed valid and in",
"valid citations ",
"[",
"1",
"]",
"[",
"99",
"]",
"[",
"3",
"]",
"[",
"100",
"]",
"[",
"5",
"]",
".",
],
"Mixed valid and invalid citations [[1]](https://0.com)[99][[2]]()[100][[3]](https://2.com).",
["doc_0", "doc_1", "doc_2"],
),
(
"Hardest!",
[
"Multiple cit",
"ations in one ",
"sentence [",
"1",
"]",
"[",
"4",
"]",
"[",
"5",
"]",
". ",
],
"Multiple citations in one sentence [[1]](https://0.com)[[2]]()[[3]](https://2.com).",
["doc_0", "doc_1", "doc_2"],
),
(
"Repeated citations with text",
["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."],
"[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .",
["doc_0"],
),
(
"Consecutive identical citations!",
[
"Citations [",
"1",
"]",
"[",
"1]",
"",
"[2",
"",
"]",
". ",
],
"Citations [[1]](https://0.com).",
["doc_0"],
),
(
"Consecutive identical citations!",
[
"test [1]tt[1]t",
"",
],
"test [[1]](https://0.com)ttt",
["doc_0"],
),
(
"Consecutive identical citations!",
[
"test [1]t[1]t[1]",
"",
],
"test [[1]](https://0.com)tt",
["doc_0"],
),
(
"Repeated citations with text",
["[", "1", "]", "Aasf", "asda", "sff ", "[", "1", "]", " ."],
"[[1]](https://0.com)Aasfasdasff [[1]](https://0.com) .",
["doc_0"],
),
(
"Repeated citations with text",
["[1][", "1", "]t", "[2]"],
"[[1]](https://0.com)t",
["doc_0"],
),
(
"Repeated citations with text",
["[1][", "1", "]t]", "[2]"],
"[[1]](https://0.com)t]",
["doc_0"],
),
(
"Repeated citations with text",
["[1][", "3", "]t]", "[2]"],
"[[1]](https://0.com)[[2]]()t]",
["doc_0", "doc_1"],
),
(
"Repeated citations with text",
["[1", "][", "3", "]t]", "[2]"],
"[[1]](https://0.com)[[2]]()t]",
["doc_0", "doc_1"],
),
],
)
def test_citation_extraction(
mock_data: tuple[list[LlmDoc], dict[str, int]],
test_name: str,
input_tokens: list[str],
expected_text: str,
expected_citations: list[str],
) -> None:
final_answer_text, citations = process_text(input_tokens, mock_data)
assert (
final_answer_text.strip() == expected_text.strip()
), f"Test '{test_name}' failed: Final answer text does not match expected output."
assert [
citation.document_id for citation in citations
] == expected_citations, (
f"Test '{test_name}' failed: Citations do not match expected output."
)
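Following the module docstring above, a new scenario would be added as one more tuple in the parametrize list. A hypothetical sketch (not part of the commit; the expected strings assume doc_0 keeps rank 1 and its https://0.com link):

(
    "Citation split across several tokens",
    ["Answer ", "[", "1", "]", " done."],
    "Answer [[1]](https://0.com) done.",
    ["doc_0"],
),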
@@ -1 +1 @@
f1f2 1 1718910083.03085 wikipedia:en
f1f2 1 1718910083.03085 wikipedia:en