mirror of https://github.com/danswer-ai/danswer.git
Merge Sections Logic (#1801)
This commit is contained in:
parent 47a550221f
commit 08c6e821e7
21  .vscode/launch.template.jsonc (vendored)
@@ -83,6 +83,7 @@
            "cwd": "${workspaceFolder}/backend",
            "envFile": "${workspaceFolder}/.env",
            "env": {
                "LOG_DANSWER_MODEL_INTERACTIONS": "True",
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
@@ -105,6 +106,24 @@
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            }
        },
        {
            "name": "Pytest",
            "type": "python",
            "request": "launch",
            "module": "pytest",
            "cwd": "${workspaceFolder}/backend",
            "envFile": "${workspaceFolder}/.env",
            "env": {
                "LOG_LEVEL": "DEBUG",
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            },
            "args": [
                "-v"
                // Specify a specific module/test to run or provide nothing to run all tests
                //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
            ]
        }
    ]
}
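The new "Pytest" entry just runs pytest as a module from backend/ with verbose output, optionally scoped to a single test file. A rough Python equivalent (a sketch, not part of the commit; assumes it is launched from the backend/ directory):

import sys

import pytest

# Same effect as the launch configuration's `"module": "pytest"` with `"args": ["-v"]`,
# here scoped to the test file this commit adds.
sys.exit(pytest.main(["-v", "tests/unit/danswer/llm/answering/test_prune_and_merge.py"]))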
@@ -28,6 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest matching base size chunk, how many chunks above and below do we pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
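CONTEXT_CHUNKS_ABOVE and CONTEXT_CHUNKS_BELOW are read straight from the environment with an `or 0` fallback, so an unset or empty variable means no extra surrounding chunks. A standalone sketch of that parsing behavior (illustrative only, not Danswer code):

import os

# Unset -> get() returns None -> `or 0` -> default of 0 extra chunks.
os.environ.pop("CONTEXT_CHUNKS_ABOVE", None)
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 0

# An empty string is falsy too, so it falls back to 0 instead of failing on int("").
os.environ["CONTEXT_CHUNKS_ABOVE"] = ""
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 0

# An explicit value is honored.
os.environ["CONTEXT_CHUNKS_ABOVE"] = "2"
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 2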
@@ -1,7 +1,10 @@
import json
from collections import defaultdict
from copy import deepcopy
from typing import TypeVar

from pydantic import BaseModel

from danswer.chat.models import (
    LlmDoc,
)
@@ -31,6 +34,36 @@ class PruningError(Exception):
    pass


class ChunkRange(BaseModel):
    chunks: list[InferenceChunk]
    start: int
    end: int


def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
    """
    This acts on a single document to merge the overlapping ranges of chunks
    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals

    NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
    document index, this does not merge the actual contents so it should not be used to actually
    merge chunks post retrieval.
    """
    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)

    combined_ranges: list[ChunkRange] = []

    for new_chunk_range in sorted_ranges:
        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
            combined_ranges.append(new_chunk_range)
        else:
            current_range = combined_ranges[-1]
            current_range.end = max(current_range.end, new_chunk_range.end)
            current_range.chunks.extend(new_chunk_range.chunks)

    return combined_ranges

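A usage sketch (illustrative only, with empty chunk lists to keep it short): because of the `start - 1` comparison, ranges that overlap or sit back-to-back are fused, while a range separated by a real gap stays on its own.

ranges = [
    ChunkRange(chunks=[], start=0, end=2),
    ChunkRange(chunks=[], start=3, end=5),   # starts immediately after the previous range
    ChunkRange(chunks=[], start=9, end=11),  # separated by a gap, so kept separate
]

# [(0, 5), (9, 11)]
print([(r.start, r.end) for r in merge_chunk_intervals(ranges)])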
def _compute_limit(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
@@ -219,6 +252,7 @@ def prune_sections(
    question: str,
    document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    if section_relevance_list is not None:
        assert len(sections) == len(section_relevance_list)
@@ -241,6 +275,67 @@ def prune_sections(
    )


def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
    # Assuming there are no duplicates by this point
    sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

    center_chunk = max(
        chunks, key=lambda x: x.score if x.score is not None else float("-inf")
    )

    merged_content = []
    for i, chunk in enumerate(sorted_chunks):
        if i > 0:
            prev_chunk_id = sorted_chunks[i - 1].chunk_id
            if chunk.chunk_id == prev_chunk_id + 1:
                merged_content.append("\n")
            else:
                merged_content.append("\n\n...\n\n")
        merged_content.append(chunk.content)

    combined_content = "".join(merged_content)

    return InferenceSection(
        center_chunk=center_chunk,
        chunks=sorted_chunks,
        combined_content=combined_content,
    )

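A simplified standalone sketch (not the commit's code) of the joining rule in _merge_doc_chunks: chunks with consecutive chunk_ids are stitched together with a single newline, and any gap in ids is marked with an ellipsis separator, which is the exact pattern the new unit tests assert on.

# (chunk_id, content) pairs standing in for InferenceChunk objects
chunks = [(4, "Content 4"), (5, "Content 5"), (8, "Content 8")]

parts: list[str] = []
for i, (chunk_id, content) in enumerate(chunks):
    if i > 0:
        prev_id = chunks[i - 1][0]
        parts.append("\n" if chunk_id == prev_id + 1 else "\n\n...\n\n")
    parts.append(content)

# "Content 4\nContent 5\n\n...\n\nContent 8"
print("".join(parts))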
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
    docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
    doc_order: dict[str, int] = {}
    for index, section in enumerate(sections):
        if section.center_chunk.document_id not in doc_order:
            doc_order[section.center_chunk.document_id] = index
        for chunk in [section.center_chunk] + section.chunks:
            chunks_map = docs_map[section.center_chunk.document_id]
            existing_chunk = chunks_map.get(chunk.chunk_id)
            if (
                existing_chunk is None
                or existing_chunk.score is None
                or chunk.score is not None
                and chunk.score > existing_chunk.score
            ):
                chunks_map[chunk.chunk_id] = chunk

    new_sections = []
    for section_chunks in docs_map.values():
        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))

    # Sort by highest score, then by original document order
    # It is now 1 large section per doc, the center chunk being the one with the highest score
    new_sections.sort(
        key=lambda x: (
            x.center_chunk.score if x.center_chunk.score is not None else 0,
            -1 * doc_order[x.center_chunk.document_id],
        ),
        reverse=True,
    )

    return new_sections

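The resulting order can be read off the sort key: best chunk score first, and ties fall back to the original (earlier-is-better) document order because the negated doc_order is also sorted in reverse. A small standalone sketch of just that key, with made-up scores (not Danswer code):

# (doc_id, best_score, original_doc_order)
docs = [("docA", 5.0, 0), ("docB", 7.0, 1), ("docC", 5.0, 2)]
docs.sort(key=lambda x: (x[1], -1 * x[2]), reverse=True)

# docB first (highest score); docA before docC (tied score, earlier original order)
print([doc_id for doc_id, _, _ in docs])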
def prune_and_merge_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
@@ -249,6 +344,7 @@ def prune_and_merge_sections(
    question: str,
    document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    remaining_sections = prune_sections(
        sections=sections,
        section_relevance_list=section_relevance_list,
@@ -257,6 +353,7 @@ def prune_and_merge_sections(
        question=question,
        document_pruning_config=document_pruning_config,
    )
    # TODO add the actual section combination logic

    return remaining_sections
    merged_sections = _merge_sections(sections=remaining_sections)

    return merged_sections
@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel
from pydantic import validator

from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
@@ -47,8 +49,8 @@ class ChunkMetric(BaseModel):
class ChunkContext(BaseModel):
    # Additional surrounding context options, if full doc, then chunks are deduped
    # If surrounding context overlap, it is combined into one
    chunks_above: int = 0
    chunks_below: int = 0
    chunks_above: int = CONTEXT_CHUNKS_ABOVE
    chunks_below: int = CONTEXT_CHUNKS_BELOW
    full_doc: bool = False

    @validator("chunks_above", "chunks_below", pre=True, each_item=False)
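With these defaults, requests that omit chunks_above/chunks_below now pick up the deployment-wide CONTEXT_CHUNKS_ABOVE/CONTEXT_CHUNKS_BELOW values instead of a hardcoded 0, while explicit values still win. A minimal standalone sketch of the pattern (hypothetical class name, plain pydantic):

import os

from pydantic import BaseModel

CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)


class ChunkContextSketch(BaseModel):
    chunks_above: int = CONTEXT_CHUNKS_ABOVE
    chunks_below: int = CONTEXT_CHUNKS_BELOW
    full_doc: bool = False


print(ChunkContextSketch())                # env-driven defaults
print(ChunkContextSketch(chunks_above=2))  # explicit request values still win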
@@ -3,13 +3,14 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast

from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -30,32 +31,6 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()


class ChunkRange(BaseModel):
    chunks: list[InferenceChunk]
    start: int
    end: int
    combined_content: str | None = None


def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
    """This acts on a single document to merge the overlapping ranges of sections
    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
    """
    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)

    combined_ranges: list[ChunkRange] = []

    for new_chunk_range in sorted_ranges:
        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start:
            combined_ranges.append(new_chunk_range)
        else:
            current_range = combined_ranges[-1]
            current_range.end = max(current_range.end, new_chunk_range.end)
            current_range.chunks.extend(new_chunk_range.chunks)

    return combined_ranges


class SearchPipeline:
    def __init__(
        self,
229  backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py (Normal file)
@@ -0,0 +1,229 @@
import pytest

from danswer.configs.constants import DocumentSource
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection


# This large test accounts for all of the following:
# 1. Merging of adjacent sections
# 2. Merging of non-adjacent sections
# 3. Merging of sections where there are multiple documents
# 4. Verifying the contents of merged sections
# 5. Verifying the order/score of the merged sections


def create_inference_chunk(
    document_id: str, chunk_id: int, content: str, score: float | None
) -> InferenceChunk:
    """
    Create an InferenceChunk with hardcoded values for testing purposes.
    """
    return InferenceChunk(
        chunk_id=chunk_id,
        document_id=document_id,
        semantic_identifier=f"{document_id}_{chunk_id}",
        blurb=f"{document_id}_{chunk_id}",
        content=content,
        source_links={0: "fake_link"},
        section_continuation=False,
        source_type=DocumentSource.WEB,
        boost=0,
        recency_bias=1.0,
        score=score,
        hidden=False,
        metadata={},
        match_highlights=[],
        updated_at=None,
    )

# Document 1, top connected sections
DOC_1_FILLER_1 = create_inference_chunk("doc1", 2, "Content 2", 1.0)
DOC_1_FILLER_2 = create_inference_chunk("doc1", 3, "Content 3", 2.0)
DOC_1_TOP_CHUNK = create_inference_chunk("doc1", 4, "Content 4", None)
DOC_1_MID_CHUNK = create_inference_chunk("doc1", 5, "Content 5", 4.0)
DOC_1_FILLER_3 = create_inference_chunk("doc1", 6, "Content 6", 5.0)
DOC_1_FILLER_4 = create_inference_chunk("doc1", 7, "Content 7", 6.0)
# This chunk below has the top score for testing
DOC_1_BOTTOM_CHUNK = create_inference_chunk("doc1", 8, "Content 8", 70.0)
DOC_1_FILLER_5 = create_inference_chunk("doc1", 9, "Content 9", None)
DOC_1_FILLER_6 = create_inference_chunk("doc1", 10, "Content 10", 9.0)
# Document 1, separate section
DOC_1_FILLER_7 = create_inference_chunk("doc1", 13, "Content 13", 10.0)
DOC_1_FILLER_8 = create_inference_chunk("doc1", 14, "Content 14", 11.0)
DOC_1_DISCONNECTED = create_inference_chunk("doc1", 15, "Content 15", 12.0)
DOC_1_FILLER_9 = create_inference_chunk("doc1", 16, "Content 16", 13.0)
DOC_1_FILLER_10 = create_inference_chunk("doc1", 17, "Content 17", 14.0)
# Document 2
DOC_2_FILLER_1 = create_inference_chunk("doc2", 1, "Doc 2 Content 1", 15.0)
DOC_2_FILLER_2 = create_inference_chunk("doc2", 2, "Doc 2 Content 2", 16.0)
# This chunk below has top score for testing
DOC_2_TOP_CHUNK = create_inference_chunk("doc2", 3, "Doc 2 Content 3", 170.0)
DOC_2_FILLER_3 = create_inference_chunk("doc2", 4, "Doc 2 Content 4", 18.0)
DOC_2_BOTTOM_CHUNK = create_inference_chunk("doc2", 5, "Doc 2 Content 5", 19.0)
DOC_2_FILLER_4 = create_inference_chunk("doc2", 6, "Doc 2 Content 6", 20.0)
DOC_2_FILLER_5 = create_inference_chunk("doc2", 7, "Doc 2 Content 7", 21.0)

# Doc 2 has the highest score so it comes first
EXPECTED_CONTENT_1 = """
Doc 2 Content 1
Doc 2 Content 2
Doc 2 Content 3
Doc 2 Content 4
Doc 2 Content 5
Doc 2 Content 6
Doc 2 Content 7
""".strip()


EXPECTED_CONTENT_2 = """
Content 2
Content 3
Content 4
Content 5
Content 6
Content 7
Content 8
Content 9
Content 10

...

Content 13
Content 14
Content 15
Content 16
Content 17
""".strip()

@pytest.mark.parametrize(
    "sections,expected_contents,expected_center_chunks",
    [
        (
            # Sections
            [
                # Document 1, top/middle/bot connected + disconnected section
                InferenceSection(
                    center_chunk=DOC_1_TOP_CHUNK,
                    chunks=[
                        DOC_1_FILLER_1,
                        DOC_1_FILLER_2,
                        DOC_1_TOP_CHUNK,
                        DOC_1_MID_CHUNK,
                        DOC_1_FILLER_3,
                    ],
                    combined_content="N/A",  # Not used
                ),
                InferenceSection(
                    center_chunk=DOC_1_MID_CHUNK,
                    chunks=[
                        DOC_1_FILLER_2,
                        DOC_1_TOP_CHUNK,
                        DOC_1_MID_CHUNK,
                        DOC_1_FILLER_3,
                        DOC_1_FILLER_4,
                    ],
                    combined_content="N/A",
                ),
                InferenceSection(
                    center_chunk=DOC_1_BOTTOM_CHUNK,
                    chunks=[
                        DOC_1_FILLER_3,
                        DOC_1_FILLER_4,
                        DOC_1_BOTTOM_CHUNK,
                        DOC_1_FILLER_5,
                        DOC_1_FILLER_6,
                    ],
                    combined_content="N/A",
                ),
                InferenceSection(
                    center_chunk=DOC_1_DISCONNECTED,
                    chunks=[
                        DOC_1_FILLER_7,
                        DOC_1_FILLER_8,
                        DOC_1_DISCONNECTED,
                        DOC_1_FILLER_9,
                        DOC_1_FILLER_10,
                    ],
                    combined_content="N/A",
                ),
                InferenceSection(
                    center_chunk=DOC_2_TOP_CHUNK,
                    chunks=[
                        DOC_2_FILLER_1,
                        DOC_2_FILLER_2,
                        DOC_2_TOP_CHUNK,
                        DOC_2_FILLER_3,
                        DOC_2_BOTTOM_CHUNK,
                    ],
                    combined_content="N/A",
                ),
                InferenceSection(
                    center_chunk=DOC_2_BOTTOM_CHUNK,
                    chunks=[
                        DOC_2_TOP_CHUNK,
                        DOC_2_FILLER_3,
                        DOC_2_BOTTOM_CHUNK,
                        DOC_2_FILLER_4,
                        DOC_2_FILLER_5,
                    ],
                    combined_content="N/A",
                ),
            ],
            # Expected Content
            [EXPECTED_CONTENT_1, EXPECTED_CONTENT_2],
            # Expected Center Chunks (highest scores)
            [DOC_2_TOP_CHUNK, DOC_1_BOTTOM_CHUNK],
        ),
    ],
)
def test_merge_sections(
    sections: list[InferenceSection],
    expected_contents: list[str],
    expected_center_chunks: list[InferenceChunk],
) -> None:
    sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
    merged_sections = _merge_sections(sections)
    assert merged_sections[0].combined_content == expected_contents[0]
    assert merged_sections[1].combined_content == expected_contents[1]
    assert merged_sections[0].center_chunk == expected_center_chunks[0]
    assert merged_sections[1].center_chunk == expected_center_chunks[1]

@pytest.mark.parametrize(
    "sections,expected_content,expected_center_chunk",
    [
        (
            # Sections
            [
                InferenceSection(
                    center_chunk=DOC_1_TOP_CHUNK,
                    chunks=[DOC_1_TOP_CHUNK],
                    combined_content="N/A",  # Not used
                ),
                InferenceSection(
                    center_chunk=DOC_1_MID_CHUNK,
                    chunks=[DOC_1_MID_CHUNK],
                    combined_content="N/A",
                ),
            ],
            # Expected Content
            "Content 4\nContent 5",
            # Expected Center Chunks (highest scores)
            DOC_1_MID_CHUNK,
        ),
    ],
)
def test_merge_minimal_sections(
    sections: list[InferenceSection],
    expected_content: str,
    expected_center_chunk: InferenceChunk,
) -> None:
    sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
    merged_sections = _merge_sections(sections)
    assert merged_sections[0].combined_content == expected_content
    assert merged_sections[0].center_chunk == expected_center_chunk