Merge Sections Logic (#1801)

Yuhong Sun 2024-07-10 20:14:02 -07:00 committed by GitHub
parent 47a550221f
commit 08c6e821e7
6 changed files with 358 additions and 34 deletions

View File

@@ -83,6 +83,7 @@
             "cwd": "${workspaceFolder}/backend",
             "envFile": "${workspaceFolder}/.env",
             "env": {
+                "LOG_DANSWER_MODEL_INTERACTIONS": "True",
                 "LOG_LEVEL": "DEBUG",
                 "PYTHONUNBUFFERED": "1",
                 "PYTHONPATH": "."
@@ -105,6 +106,24 @@
                 "PYTHONUNBUFFERED": "1",
                 "PYTHONPATH": "."
             }
+        },
+        {
+            "name": "Pytest",
+            "type": "python",
+            "request": "launch",
+            "module": "pytest",
+            "cwd": "${workspaceFolder}/backend",
+            "envFile": "${workspaceFolder}/.env",
+            "env": {
+                "LOG_LEVEL": "DEBUG",
+                "PYTHONUNBUFFERED": "1",
+                "PYTHONPATH": "."
+            },
+            "args": [
+                "-v"
+                // Specify a specific module/test to run or provide nothing to run all tests
+                //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
+            ]
         }
     ]
 }

View File

@@ -28,6 +28,10 @@ BASE_RECENCY_DECAY = 0.5
 FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
 # Currently this next one is not configurable via env
 DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
+# For the highest matching base size chunk, how many chunks above and below do we pull in by default
+# Note this is not in any of the deployment configs yet
+CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
+CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
 # Whether the LLM should evaluate all of the document chunks passed in for usefulness
 # in relation to the user query
 DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
 # 1 edit per 20 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
-# Include additional document/chunk metadata in prompt to GenerativeAI
-INCLUDE_METADATA = False
 # Keyword Search Drop Stopwords
 # If user has changed the default model, would most likely be to use a multilingual
 # model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
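
A note on the `or 0` fallback used for the two new settings: unlike a plain default argument to `os.environ.get`, it treats an env var that is set but empty the same as an unset one. A minimal sketch of the difference:

import os

# Simulate a var that is present but empty, e.g. `CONTEXT_CHUNKS_ABOVE=` in a .env file
os.environ["CONTEXT_CHUNKS_ABOVE"] = ""

# The pattern used above: "" is falsy, so the fallback 0 kicks in
print(int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0))  # 0

# A plain default would crash instead, since int("") raises ValueError:
# int(os.environ.get("CONTEXT_CHUNKS_ABOVE", 0))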

View File

@@ -1,7 +1,10 @@
 import json
+from collections import defaultdict
 from copy import deepcopy
 from typing import TypeVar
+
+from pydantic import BaseModel
 from danswer.chat.models import (
     LlmDoc,
 )
@@ -31,6 +34,36 @@
     pass
+
+
+class ChunkRange(BaseModel):
+    chunks: list[InferenceChunk]
+    start: int
+    end: int
+
+
+def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
+    """
+    This acts on a single document to merge the overlapping ranges of chunks
+    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
+
+    NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
+    document index, this does not merge the actual contents so it should not be used to actually
+    merge chunks post retrieval.
+    """
+    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
+
+    combined_ranges: list[ChunkRange] = []
+
+    for new_chunk_range in sorted_ranges:
+        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
+            combined_ranges.append(new_chunk_range)
+        else:
+            current_range = combined_ranges[-1]
+            current_range.end = max(current_range.end, new_chunk_range.end)
+            current_range.chunks.extend(new_chunk_range.chunks)
+
+    return combined_ranges
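
A quick sketch of how the new shared helper behaves (empty `chunks` lists just to keep the example small): because the comparison is against `new_chunk_range.start - 1`, ranges that overlap or sit back-to-back collapse into one, while a gap of at least one chunk id keeps them separate.

from danswer.llm.answering.prune_and_merge import ChunkRange, merge_chunk_intervals

ranges = [
    ChunkRange(chunks=[], start=0, end=4),
    ChunkRange(chunks=[], start=3, end=6),    # overlaps the first: merged
    ChunkRange(chunks=[], start=7, end=9),    # back-to-back with end=6: merged
    ChunkRange(chunks=[], start=11, end=12),  # chunk 10 is missing: kept separate
]

print([(r.start, r.end) for r in merge_chunk_intervals(ranges)])  # [(0, 9), (11, 12)]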
 def _compute_limit(
     prompt_config: PromptConfig,
     llm_config: LLMConfig,
@@ -219,6 +252,7 @@ def prune_sections(
     question: str,
     document_pruning_config: DocumentPruningConfig,
 ) -> list[InferenceSection]:
+    # Assumes the sections are score ordered with highest first
     if section_relevance_list is not None:
         assert len(sections) == len(section_relevance_list)
@@ -241,6 +275,67 @@ def prune_sections(
     )
+
+
+def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
+    # Assuming there are no duplicates by this point
+    sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
+
+    center_chunk = max(
+        chunks, key=lambda x: x.score if x.score is not None else float("-inf")
+    )
+
+    merged_content = []
+    for i, chunk in enumerate(sorted_chunks):
+        if i > 0:
+            prev_chunk_id = sorted_chunks[i - 1].chunk_id
+            if chunk.chunk_id == prev_chunk_id + 1:
+                merged_content.append("\n")
+            else:
+                merged_content.append("\n\n...\n\n")
+        merged_content.append(chunk.content)
+
+    combined_content = "".join(merged_content)
+
+    return InferenceSection(
+        center_chunk=center_chunk,
+        chunks=sorted_chunks,
+        combined_content=combined_content,
+    )
+
+
+def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
+    docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
+    doc_order: dict[str, int] = {}
+    for index, section in enumerate(sections):
+        if section.center_chunk.document_id not in doc_order:
+            doc_order[section.center_chunk.document_id] = index
+        for chunk in [section.center_chunk] + section.chunks:
+            chunks_map = docs_map[section.center_chunk.document_id]
+            existing_chunk = chunks_map.get(chunk.chunk_id)
+            if (
+                existing_chunk is None
+                or existing_chunk.score is None
+                or chunk.score is not None
+                and chunk.score > existing_chunk.score
+            ):
+                chunks_map[chunk.chunk_id] = chunk
+
+    new_sections = []
+    for section_chunks in docs_map.values():
+        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
+
+    # Sort by highest score, then by original document order
+    # It is now 1 large section per doc, the center chunk being the one with the highest score
+    new_sections.sort(
+        key=lambda x: (
+            x.center_chunk.score if x.center_chunk.score is not None else 0,
+            -1 * doc_order[x.center_chunk.document_id],
+        ),
+        reverse=True,
+    )
+
+    return new_sections
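
The sort at the end of `_merge_sections` is worth unpacking: with `reverse=True` the scores sort descending, and negating the original document order makes earlier-seen documents win ties. A self-contained sketch of just that key, with (score, original order) tuples standing in for merged sections:

sections = [(5.0, 0), (9.0, 1), (5.0, 2)]

# Same key shape as above: negate the order so reverse=True sorts scores
# descending while earlier documents still come first on ties
sections.sort(key=lambda s: (s[0] if s[0] is not None else 0, -1 * s[1]), reverse=True)

print(sections)  # [(9.0, 1), (5.0, 0), (5.0, 2)]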
 def prune_and_merge_sections(
     sections: list[InferenceSection],
     section_relevance_list: list[bool] | None,
@@ -249,6 +344,7 @@ def prune_and_merge_sections(
     question: str,
     document_pruning_config: DocumentPruningConfig,
 ) -> list[InferenceSection]:
+    # Assumes the sections are score ordered with highest first
     remaining_sections = prune_sections(
         sections=sections,
         section_relevance_list=section_relevance_list,
@@ -257,6 +353,7 @@ def prune_and_merge_sections(
         question=question,
         document_pruning_config=document_pruning_config,
     )
-    # TODO add the actual section combination logic
-    return remaining_sections
+    merged_sections = _merge_sections(sections=remaining_sections)
+
+    return merged_sections
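
To see what `_merge_doc_chunks` actually produces, the joining rule is the whole story: consecutive chunk ids are glued with a single newline, while a gap in ids becomes an ellipsis separator. A rough sketch of just that rule, with a hypothetical two-field stand-in for InferenceChunk:

from dataclasses import dataclass

@dataclass
class StubChunk:  # hypothetical stand-in; only the fields the joining rule needs
    chunk_id: int
    content: str

def join_contents(chunks: list[StubChunk]) -> str:
    # Same rule as _merge_doc_chunks: "\n" between consecutive chunk ids,
    # "\n\n...\n\n" when there is a gap
    parts: list[str] = []
    for i, chunk in enumerate(chunks):
        if i > 0:
            adjacent = chunk.chunk_id == chunks[i - 1].chunk_id + 1
            parts.append("\n" if adjacent else "\n\n...\n\n")
        parts.append(chunk.content)
    return "".join(parts)

# Chunks 4 and 5 are adjacent; chunk 8 follows a gap
print(join_contents([StubChunk(4, "Content 4"), StubChunk(5, "Content 5"), StubChunk(8, "Content 8")]))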

View File

@@ -4,6 +4,8 @@ from typing import Any
 from pydantic import BaseModel
 from pydantic import validator
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
 from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
@@ -47,8 +49,8 @@ class ChunkMetric(BaseModel):
 class ChunkContext(BaseModel):
     # Additional surrounding context options, if full doc, then chunks are deduped
     # If surrounding context overlap, it is combined into one
-    chunks_above: int = 0
-    chunks_below: int = 0
+    chunks_above: int = CONTEXT_CHUNKS_ABOVE
+    chunks_below: int = CONTEXT_CHUNKS_BELOW
     full_doc: bool = False

     @validator("chunks_above", "chunks_below", pre=True, each_item=False)
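
With these defaults in place, request models inheriting ChunkContext pick up the env-driven values without callers passing them explicitly. One caveat worth a sketch (assuming a danswer dev environment): pydantic reads the default from the constant at class-definition time, so the env vars must be set before danswer is imported.

import os

# Must happen before the danswer import below
os.environ["CONTEXT_CHUNKS_ABOVE"] = "1"
os.environ["CONTEXT_CHUNKS_BELOW"] = "1"

from danswer.search.models import ChunkContext

ctx = ChunkContext()
print(ctx.chunks_above, ctx.chunks_below)  # 1 1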

View File

@@ -3,13 +3,14 @@ from collections.abc import Callable
 from collections.abc import Iterator
 from typing import cast

-from pydantic import BaseModel
 from sqlalchemy.orm import Session

 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
+from danswer.llm.answering.prune_and_merge import ChunkRange
+from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
 from danswer.llm.interfaces import LLM
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
@@ -30,32 +31,6 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

 logger = setup_logger()

-class ChunkRange(BaseModel):
-    chunks: list[InferenceChunk]
-    start: int
-    end: int
-    combined_content: str | None = None
-
-
-def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
-    """This acts on a single document to merge the overlapping ranges of sections
-    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
-    """
-    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
-
-    combined_ranges: list[ChunkRange] = []
-
-    for new_chunk_range in sorted_ranges:
-        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start:
-            combined_ranges.append(new_chunk_range)
-        else:
-            current_range = combined_ranges[-1]
-            current_range.end = max(current_range.end, new_chunk_range.end)
-            current_range.chunks.extend(new_chunk_range.chunks)
-
-    return combined_ranges
-

 class SearchPipeline:
     def __init__(
         self,
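
Moving the helper out of the pipeline also tightens its behavior: the old copy removed here only merged ranges that actually overlap, while the shared version in prune_and_merge.py merges immediately adjacent ranges too. The difference is the `- 1` in the comparison:

prev_end, next_start = 4, 5  # back-to-back chunk ranges

print(prev_end < next_start)      # True  -> old pipeline code kept two separate ranges
print(prev_end < next_start - 1)  # False -> new shared helper merges them into one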

View File

@@ -0,0 +1,229 @@
+import pytest
+
+from danswer.configs.constants import DocumentSource
+from danswer.llm.answering.prune_and_merge import _merge_sections
+from danswer.search.models import InferenceChunk
+from danswer.search.models import InferenceSection
+
+
+# This large test accounts for all of the following:
+# 1. Merging of adjacent sections
+# 2. Merging of non-adjacent sections
+# 3. Merging of sections where there are multiple documents
+# 4. Verifying the contents of merged sections
+# 5. Verifying the order/score of the merged sections
+def create_inference_chunk(
+    document_id: str, chunk_id: int, content: str, score: float | None
+) -> InferenceChunk:
+    """
+    Create an InferenceChunk with hardcoded values for testing purposes.
+    """
+    return InferenceChunk(
+        chunk_id=chunk_id,
+        document_id=document_id,
+        semantic_identifier=f"{document_id}_{chunk_id}",
+        blurb=f"{document_id}_{chunk_id}",
+        content=content,
+        source_links={0: "fake_link"},
+        section_continuation=False,
+        source_type=DocumentSource.WEB,
+        boost=0,
+        recency_bias=1.0,
+        score=score,
+        hidden=False,
+        metadata={},
+        match_highlights=[],
+        updated_at=None,
+    )
+# Document 1, top connected sections
+DOC_1_FILLER_1 = create_inference_chunk("doc1", 2, "Content 2", 1.0)
+DOC_1_FILLER_2 = create_inference_chunk("doc1", 3, "Content 3", 2.0)
+DOC_1_TOP_CHUNK = create_inference_chunk("doc1", 4, "Content 4", None)
+DOC_1_MID_CHUNK = create_inference_chunk("doc1", 5, "Content 5", 4.0)
+DOC_1_FILLER_3 = create_inference_chunk("doc1", 6, "Content 6", 5.0)
+DOC_1_FILLER_4 = create_inference_chunk("doc1", 7, "Content 7", 6.0)
+# This chunk below has the top score for testing
+DOC_1_BOTTOM_CHUNK = create_inference_chunk("doc1", 8, "Content 8", 70.0)
+DOC_1_FILLER_5 = create_inference_chunk("doc1", 9, "Content 9", None)
+DOC_1_FILLER_6 = create_inference_chunk("doc1", 10, "Content 10", 9.0)
+
+# Document 1, separate section
+DOC_1_FILLER_7 = create_inference_chunk("doc1", 13, "Content 13", 10.0)
+DOC_1_FILLER_8 = create_inference_chunk("doc1", 14, "Content 14", 11.0)
+DOC_1_DISCONNECTED = create_inference_chunk("doc1", 15, "Content 15", 12.0)
+DOC_1_FILLER_9 = create_inference_chunk("doc1", 16, "Content 16", 13.0)
+DOC_1_FILLER_10 = create_inference_chunk("doc1", 17, "Content 17", 14.0)
+
+# Document 2
+DOC_2_FILLER_1 = create_inference_chunk("doc2", 1, "Doc 2 Content 1", 15.0)
+DOC_2_FILLER_2 = create_inference_chunk("doc2", 2, "Doc 2 Content 2", 16.0)
+# This chunk below has top score for testing
+DOC_2_TOP_CHUNK = create_inference_chunk("doc2", 3, "Doc 2 Content 3", 170.0)
+DOC_2_FILLER_3 = create_inference_chunk("doc2", 4, "Doc 2 Content 4", 18.0)
+DOC_2_BOTTOM_CHUNK = create_inference_chunk("doc2", 5, "Doc 2 Content 5", 19.0)
+DOC_2_FILLER_4 = create_inference_chunk("doc2", 6, "Doc 2 Content 6", 20.0)
+DOC_2_FILLER_5 = create_inference_chunk("doc2", 7, "Doc 2 Content 7", 21.0)
+
+
+# Doc 2 has the highest score so it comes first
+EXPECTED_CONTENT_1 = """
+Doc 2 Content 1
+Doc 2 Content 2
+Doc 2 Content 3
+Doc 2 Content 4
+Doc 2 Content 5
+Doc 2 Content 6
+Doc 2 Content 7
+""".strip()
+
+EXPECTED_CONTENT_2 = """
+Content 2
+Content 3
+Content 4
+Content 5
+Content 6
+Content 7
+Content 8
+Content 9
+Content 10
+
+...
+
+Content 13
+Content 14
+Content 15
+Content 16
+Content 17
+""".strip()
+@pytest.mark.parametrize(
+    "sections,expected_contents,expected_center_chunks",
+    [
+        (
+            # Sections
+            [
+                # Document 1, top/middle/bot connected + disconnected section
+                InferenceSection(
+                    center_chunk=DOC_1_TOP_CHUNK,
+                    chunks=[
+                        DOC_1_FILLER_1,
+                        DOC_1_FILLER_2,
+                        DOC_1_TOP_CHUNK,
+                        DOC_1_MID_CHUNK,
+                        DOC_1_FILLER_3,
+                    ],
+                    combined_content="N/A",  # Not used
+                ),
+                InferenceSection(
+                    center_chunk=DOC_1_MID_CHUNK,
+                    chunks=[
+                        DOC_1_FILLER_2,
+                        DOC_1_TOP_CHUNK,
+                        DOC_1_MID_CHUNK,
+                        DOC_1_FILLER_3,
+                        DOC_1_FILLER_4,
+                    ],
+                    combined_content="N/A",
+                ),
+                InferenceSection(
+                    center_chunk=DOC_1_BOTTOM_CHUNK,
+                    chunks=[
+                        DOC_1_FILLER_3,
+                        DOC_1_FILLER_4,
+                        DOC_1_BOTTOM_CHUNK,
+                        DOC_1_FILLER_5,
+                        DOC_1_FILLER_6,
+                    ],
+                    combined_content="N/A",
+                ),
+                InferenceSection(
+                    center_chunk=DOC_1_DISCONNECTED,
+                    chunks=[
+                        DOC_1_FILLER_7,
+                        DOC_1_FILLER_8,
+                        DOC_1_DISCONNECTED,
+                        DOC_1_FILLER_9,
+                        DOC_1_FILLER_10,
+                    ],
+                    combined_content="N/A",
+                ),
+                InferenceSection(
+                    center_chunk=DOC_2_TOP_CHUNK,
+                    chunks=[
+                        DOC_2_FILLER_1,
+                        DOC_2_FILLER_2,
+                        DOC_2_TOP_CHUNK,
+                        DOC_2_FILLER_3,
+                        DOC_2_BOTTOM_CHUNK,
+                    ],
+                    combined_content="N/A",
+                ),
+                InferenceSection(
+                    center_chunk=DOC_2_BOTTOM_CHUNK,
+                    chunks=[
+                        DOC_2_TOP_CHUNK,
+                        DOC_2_FILLER_3,
+                        DOC_2_BOTTOM_CHUNK,
+                        DOC_2_FILLER_4,
+                        DOC_2_FILLER_5,
+                    ],
+                    combined_content="N/A",
+                ),
+            ],
+            # Expected Content
+            [EXPECTED_CONTENT_1, EXPECTED_CONTENT_2],
+            # Expected Center Chunks (highest scores)
+            [DOC_2_TOP_CHUNK, DOC_1_BOTTOM_CHUNK],
+        ),
+    ],
+)
+def test_merge_sections(
+    sections: list[InferenceSection],
+    expected_contents: list[str],
+    expected_center_chunks: list[InferenceChunk],
+) -> None:
+    sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
+
+    merged_sections = _merge_sections(sections)
+    assert merged_sections[0].combined_content == expected_contents[0]
+    assert merged_sections[1].combined_content == expected_contents[1]
+    assert merged_sections[0].center_chunk == expected_center_chunks[0]
+    assert merged_sections[1].center_chunk == expected_center_chunks[1]
+
+
+@pytest.mark.parametrize(
+    "sections,expected_content,expected_center_chunk",
+    [
+        (
+            # Sections
+            [
+                InferenceSection(
+                    center_chunk=DOC_1_TOP_CHUNK,
+                    chunks=[DOC_1_TOP_CHUNK],
+                    combined_content="N/A",  # Not used
+                ),
+                InferenceSection(
+                    center_chunk=DOC_1_MID_CHUNK,
+                    chunks=[DOC_1_MID_CHUNK],
+                    combined_content="N/A",
+                ),
+            ],
+            # Expected Content
+            "Content 4\nContent 5",
+            # Expected Center Chunk (highest score)
+            DOC_1_MID_CHUNK,
+        ),
+    ],
+)
+def test_merge_minimal_sections(
+    sections: list[InferenceSection],
+    expected_content: str,
+    expected_center_chunk: InferenceChunk,
+) -> None:
+    sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True)
+
+    merged_sections = _merge_sections(sections)
+    assert merged_sections[0].combined_content == expected_content
+    assert merged_sections[0].center_chunk == expected_center_chunk
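
For reference, this suite can be run with the new "Pytest" launch configuration added in the first file, or programmatically from the backend directory; a minimal sketch using pytest's own API:

import pytest

# Equivalent to the new launch configuration: run from backend/ with
# PYTHONPATH=. so the danswer package resolves
raise SystemExit(pytest.main(["-v", "tests/unit/danswer/llm/answering/test_prune_and_merge.py"]))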