Merge Sections Logic (#1801)

Yuhong Sun authored 2024-07-10 20:14:02 -07:00; committed by GitHub
parent 47a550221f
commit 08c6e821e7
6 changed files with 358 additions and 34 deletions

.vscode/launch.json

@@ -83,6 +83,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_DANSWER_MODEL_INTERACTIONS": "True",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
@@ -105,6 +106,24 @@
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
}
},
{
"name": "Pytest",
"type": "python",
"request": "launch",
"module": "pytest",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"-v"
// Specify a specific module/test to run, or provide nothing to run all tests
//"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
]
}
]
}
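
For reference, the same run can be reproduced outside VS Code. A minimal sketch of the programmatic equivalent, assuming it is executed from the backend directory (so PYTHONPATH="." resolves) with the .env values exported:

import pytest

# Mirrors the "Pytest" launch configuration above: verbose mode, all tests.
# Append a path argument to narrow the run, e.g. the test file added in this commit.
if __name__ == "__main__":
    raise SystemExit(pytest.main(["-v"]))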

backend/danswer/configs/chat_configs.py

@@ -28,6 +28,10 @@ BASE_RECENCY_DECAY = 0.5
FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
# Currently this next one is not configurable via env
DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
# For the highest-matching base-size chunk, how many chunks above and below to pull in by default
# Note this is not in any of the deployment configs yet
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
# in relation to the user query
DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
- # Include additional document/chunk metadata in prompt to GenerativeAI
- INCLUDE_METADATA = False
# Keyword Search Drop Stopwords
# If the user has changed the default model, it would most likely be to use a multilingual
# model; the stopwords are NLTK English stopwords, so then we would want to avoid dropping the keywords
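
A note on the new CONTEXT_CHUNKS_ABOVE/BELOW parsing: int(os.environ.get(...) or 0) falls back to 0 for both an unset variable and an empty string, which a plain int(os.environ.get(..., "0")) would not (int("") raises). A small sketch of that behavior:

import os

# Unset and empty-string env vars both fall through "or 0" to the default.
os.environ.pop("CONTEXT_CHUNKS_ABOVE", None)
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 0  # unset -> 0

os.environ["CONTEXT_CHUNKS_ABOVE"] = ""
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 0  # empty -> 0

os.environ["CONTEXT_CHUNKS_ABOVE"] = "2"
assert int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0) == 2  # set -> parsed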

backend/danswer/llm/answering/prune_and_merge.py

@@ -1,7 +1,10 @@
import json
from collections import defaultdict
from copy import deepcopy
from typing import TypeVar
from pydantic import BaseModel
from danswer.chat.models import (
LlmDoc,
)
@@ -31,6 +34,36 @@ class PruningError(Exception):
pass
class ChunkRange(BaseModel):
chunks: list[InferenceChunk]
start: int
end: int
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
"""
This acts on a single document to merge the overlapping ranges of chunks
Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
document index, this does not merge the actual contents so it should not be used to actually
merge chunks post retrieval.
"""
sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
combined_ranges: list[ChunkRange] = []
for new_chunk_range in sorted_ranges:
if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
combined_ranges.append(new_chunk_range)
else:
current_range = combined_ranges[-1]
current_range.end = max(current_range.end, new_chunk_range.end)
current_range.chunks.extend(new_chunk_range.chunks)
return combined_ranges
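
A small usage sketch of the relocated helper, assuming the danswer package is importable; the empty chunk lists are placeholders standing in for real retrieved chunks:

from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals

# end == next start - 1 counts as adjacent, so [0, 2] and [3, 5] merge;
# [7, 9] stays separate because of the gap at chunk 6.
ranges = [
    ChunkRange(chunks=[], start=0, end=2),
    ChunkRange(chunks=[], start=3, end=5),
    ChunkRange(chunks=[], start=7, end=9),
]
merged = merge_chunk_intervals(ranges)
assert [(r.start, r.end) for r in merged] == [(0, 5), (7, 9)]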
def _compute_limit(
prompt_config: PromptConfig,
llm_config: LLMConfig,
@@ -219,6 +252,7 @@ def prune_sections(
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
# Assumes the sections are ordered by score, highest first
if section_relevance_list is not None:
assert len(sections) == len(section_relevance_list)
@@ -241,6 +275,67 @@
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
center_chunk = max(
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
)
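
The joining rule here is what the tests below assert against: consecutive chunk_ids are joined with a single newline, and any gap inserts an ellipsis separator. A standalone sketch of just that rule, using hypothetical (chunk_id, content) pairs in place of InferenceChunk:

def join_contents(chunks: list[tuple[int, str]]) -> str:
    # Hypothetical stand-in: (chunk_id, content) pairs replace InferenceChunk.
    ordered = sorted(chunks)
    parts: list[str] = []
    for i, (chunk_id, content) in enumerate(ordered):
        if i > 0:
            # Adjacent ids -> newline; a gap -> ellipsis separator.
            parts.append("\n" if chunk_id == ordered[i - 1][0] + 1 else "\n\n...\n\n")
        parts.append(content)
    return "".join(parts)

assert join_contents([(4, "A"), (5, "B"), (8, "C")]) == "A\nB\n\n...\n\nC"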
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
# Keep whichever copy of a duplicated chunk carries the higher score
if (
existing_chunk is None
or existing_chunk.score is None
or (chunk.score is not None and chunk.score > existing_chunk.score)
):
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
# Sort by highest score, then by original document order
# There is now one large section per doc, with the center chunk being the one with the highest score
new_sections.sort(
key=lambda x: (
x.center_chunk.score if x.center_chunk.score is not None else 0,
-1 * doc_order[x.center_chunk.document_id],
),
reverse=True,
)
return new_sections
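
The sort key deserves a note: with reverse=True over (score, -doc_order), sections are ordered by descending score, and equal scores fall back to the earliest original document position. A tuple-only sketch with hypothetical scores:

# (score, -doc_order) keys for three hypothetical sections, sorted as above.
keys = [(5.0, -1 * 0), (9.0, -1 * 1), (5.0, -1 * 2)]
keys.sort(reverse=True)
# Highest score first; the tie at 5.0 resolves to the earlier doc (order 0).
assert keys == [(9.0, -1), (5.0, 0), (5.0, -2)]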
def prune_and_merge_sections(
sections: list[InferenceSection],
section_relevance_list: list[bool] | None,
@@ -249,6 +344,7 @@ def prune_and_merge_sections(
question: str,
document_pruning_config: DocumentPruningConfig,
) -> list[InferenceSection]:
# Assumes the sections are ordered by score, highest first
remaining_sections = prune_sections(
sections=sections,
section_relevance_list=section_relevance_list,
@@ -257,6 +353,7 @@
question=question,
document_pruning_config=document_pruning_config,
)
- # TODO add the actual section combination logic
- return remaining_sections
+ merged_sections = _merge_sections(sections=remaining_sections)
+ return merged_sections

backend/danswer/search/models.py

@@ -4,6 +4,8 @@ from typing import Any
from pydantic import BaseModel
from pydantic import validator
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
@@ -47,8 +49,8 @@ class ChunkMetric(BaseModel):
class ChunkContext(BaseModel):
# Additional surrounding context options; if full doc, then chunks are deduped
# If surrounding contexts overlap, they are combined into one
- chunks_above: int = 0
- chunks_below: int = 0
+ chunks_above: int = CONTEXT_CHUNKS_ABOVE
+ chunks_below: int = CONTEXT_CHUNKS_BELOW
full_doc: bool = False
@validator("chunks_above", "chunks_below", pre=True, each_item=False)
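
With these defaults in place, a ChunkContext constructed with no arguments now inherits the env-derived values. A minimal sketch, assuming danswer is importable and the variables were set before the config module loaded:

from danswer.search.models import ChunkContext

# Defaults come from CONTEXT_CHUNKS_ABOVE / CONTEXT_CHUNKS_BELOW, which are
# read from the environment once at import time (0 when unset).
ctx = ChunkContext()
print(ctx.chunks_above, ctx.chunks_below, ctx.full_doc)  # e.g. 0 0 False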

backend/danswer/search/pipeline.py

@@ -3,13 +3,14 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
- from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
+ from danswer.llm.answering.prune_and_merge import ChunkRange
+ from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
@@ -30,32 +31,6 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
- class ChunkRange(BaseModel):
- chunks: list[InferenceChunk]
- start: int
- end: int
- combined_content: str | None = None
- def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
- """This acts on a single document to merge the overlapping ranges of sections
- Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
- """
- sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
- combined_ranges: list[ChunkRange] = []
- for new_chunk_range in sorted_ranges:
- if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start:
- combined_ranges.append(new_chunk_range)
- else:
- current_range = combined_ranges[-1]
- current_range.end = max(current_range.end, new_chunk_range.end)
- current_range.chunks.extend(new_chunk_range.chunks)
- return combined_ranges
class SearchPipeline:
def __init__(
self,
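
One behavioral nuance in this move: the version removed above merged only ranges that actually overlap, while the relocated version in prune_and_merge.py also merges ranges that are merely adjacent (its condition compares against start - 1). A sketch of the boundary case with hypothetical values:

# Ranges [0, 2] and [3, 5]: adjacent, but not overlapping.
prev_end, next_start = 2, 3

# Old condition (this file): 2 < 3 holds, so the ranges stay separate.
assert prev_end < next_start

# New condition (prune_and_merge.py): 2 < 2 fails, so the ranges merge.
assert not (prev_end < next_start - 1)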

backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py

@@ -0,0 +1,229 @@
import pytest
from danswer.configs.constants import DocumentSource
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
# This large test covers all of the following:
# 1. Merging of adjacent sections
# 2. Merging of non-adjacent sections
# 3. Merging of sections where there are multiple documents
# 4. Verifying the contents of merged sections
# 5. Verifying the order/score of the merged sections
def create_inference_chunk(
document_id: str, chunk_id: int, content: str, score: float | None
) -> InferenceChunk:
"""
Create an InferenceChunk with hardcoded values for testing purposes.
"""
return InferenceChunk(
chunk_id=chunk_id,
document_id=document_id,
semantic_identifier=f"{document_id}_{chunk_id}",
blurb=f"{document_id}_{chunk_id}",
content=content,
source_links={0: "fake_link"},
section_continuation=False,
source_type=DocumentSource.WEB,
boost=0,
recency_bias=1.0,
score=score,
hidden=False,
metadata={},
match_highlights=[],
updated_at=None,
)
# Document 1, top connected sections
DOC_1_FILLER_1 = create_inference_chunk("doc1", 2, "Content 2", 1.0)
DOC_1_FILLER_2 = create_inference_chunk("doc1", 3, "Content 3", 2.0)
DOC_1_TOP_CHUNK = create_inference_chunk("doc1", 4, "Content 4", None)
DOC_1_MID_CHUNK = create_inference_chunk("doc1", 5, "Content 5", 4.0)
DOC_1_FILLER_3 = create_inference_chunk("doc1", 6, "Content 6", 5.0)
DOC_1_FILLER_4 = create_inference_chunk("doc1", 7, "Content 7", 6.0)
# This chunk below has the top score for testing
DOC_1_BOTTOM_CHUNK = create_inference_chunk("doc1", 8, "Content 8", 70.0)
DOC_1_FILLER_5 = create_inference_chunk("doc1", 9, "Content 9", None)
DOC_1_FILLER_6 = create_inference_chunk("doc1", 10, "Content 10", 9.0)
# Document 1, separate section
DOC_1_FILLER_7 = create_inference_chunk("doc1", 13, "Content 13", 10.0)
DOC_1_FILLER_8 = create_inference_chunk("doc1", 14, "Content 14", 11.0)
DOC_1_DISCONNECTED = create_inference_chunk("doc1", 15, "Content 15", 12.0)
DOC_1_FILLER_9 = create_inference_chunk("doc1", 16, "Content 16", 13.0)
DOC_1_FILLER_10 = create_inference_chunk("doc1", 17, "Content 17", 14.0)
# Document 2
DOC_2_FILLER_1 = create_inference_chunk("doc2", 1, "Doc 2 Content 1", 15.0)
DOC_2_FILLER_2 = create_inference_chunk("doc2", 2, "Doc 2 Content 2", 16.0)
# This chunk below has the top score for testing
DOC_2_TOP_CHUNK = create_inference_chunk("doc2", 3, "Doc 2 Content 3", 170.0)
DOC_2_FILLER_3 = create_inference_chunk("doc2", 4, "Doc 2 Content 4", 18.0)
DOC_2_BOTTOM_CHUNK = create_inference_chunk("doc2", 5, "Doc 2 Content 5", 19.0)
DOC_2_FILLER_4 = create_inference_chunk("doc2", 6, "Doc 2 Content 6", 20.0)
DOC_2_FILLER_5 = create_inference_chunk("doc2", 7, "Doc 2 Content 7", 21.0)
# Doc 2 has the highest score so it comes first
EXPECTED_CONTENT_1 = """
Doc 2 Content 1
Doc 2 Content 2
Doc 2 Content 3
Doc 2 Content 4
Doc 2 Content 5
Doc 2 Content 6
Doc 2 Content 7
""".strip()
EXPECTED_CONTENT_2 = """
Content 2
Content 3
Content 4
Content 5
Content 6
Content 7
Content 8
Content 9
Content 10

...

Content 13
Content 14
Content 15
Content 16
Content 17
""".strip()
@pytest.mark.parametrize(
"sections,expected_contents,expected_center_chunks",
[
(
# Sections
[
# Document 1, top/middle/bot connected + disconnected section
InferenceSection(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[
DOC_1_FILLER_1,
DOC_1_FILLER_2,
DOC_1_TOP_CHUNK,
DOC_1_MID_CHUNK,
DOC_1_FILLER_3,
],
combined_content="N/A", # Not used
),
InferenceSection(
center_chunk=DOC_1_MID_CHUNK,
chunks=[
DOC_1_FILLER_2,
DOC_1_TOP_CHUNK,
DOC_1_MID_CHUNK,
DOC_1_FILLER_3,
DOC_1_FILLER_4,
],
combined_content="N/A",
),
InferenceSection(
center_chunk=DOC_1_BOTTOM_CHUNK,
chunks=[
DOC_1_FILLER_3,
DOC_1_FILLER_4,
DOC_1_BOTTOM_CHUNK,
DOC_1_FILLER_5,
DOC_1_FILLER_6,
],
combined_content="N/A",
),
InferenceSection(
center_chunk=DOC_1_DISCONNECTED,
chunks=[
DOC_1_FILLER_7,
DOC_1_FILLER_8,
DOC_1_DISCONNECTED,
DOC_1_FILLER_9,
DOC_1_FILLER_10,
],
combined_content="N/A",
),
InferenceSection(
center_chunk=DOC_2_TOP_CHUNK,
chunks=[
DOC_2_FILLER_1,
DOC_2_FILLER_2,
DOC_2_TOP_CHUNK,
DOC_2_FILLER_3,
DOC_2_BOTTOM_CHUNK,
],
combined_content="N/A",
),
InferenceSection(
center_chunk=DOC_2_BOTTOM_CHUNK,
chunks=[
DOC_2_TOP_CHUNK,
DOC_2_FILLER_3,
DOC_2_BOTTOM_CHUNK,
DOC_2_FILLER_4,
DOC_2_FILLER_5,
],
combined_content="N/A",
),
],
# Expected Content
[EXPECTED_CONTENT_1, EXPECTED_CONTENT_2],
# Expected Center Chunks (highest scores)
[DOC_2_TOP_CHUNK, DOC_1_BOTTOM_CHUNK],
),
],
)
def test_merge_sections(
sections: list[InferenceSection],
expected_contents: list[str],
expected_center_chunks: list[InferenceChunk],
) -> None:
sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
merged_sections = _merge_sections(sections)
assert merged_sections[0].combined_content == expected_contents[0]
assert merged_sections[1].combined_content == expected_contents[1]
assert merged_sections[0].center_chunk == expected_center_chunks[0]
assert merged_sections[1].center_chunk == expected_center_chunks[1]
@pytest.mark.parametrize(
"sections,expected_content,expected_center_chunk",
[
(
# Sections
[
InferenceSection(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[DOC_1_TOP_CHUNK],
combined_content="N/A", # Not used
),
InferenceSection(
center_chunk=DOC_1_MID_CHUNK,
chunks=[DOC_1_MID_CHUNK],
combined_content="N/A",
),
],
# Expected Content
"Content 4\nContent 5",
# Expected Center Chunk (highest score)
DOC_1_MID_CHUNK,
),
],
)
def test_merge_minimal_sections(
sections: list[InferenceSection],
expected_content: str,
expected_center_chunk: InferenceChunk,
) -> None:
sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
merged_sections = _merge_sections(sections)
assert merged_sections[0].combined_content == expected_content
assert merged_sections[0].center_chunk == expected_center_chunk