From 08c6e821e7a2a80afac4d93f412c4c9900d0ff1c Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Wed, 10 Jul 2024 20:14:02 -0700
Subject: [PATCH] Merge Sections Logic (#1801)

---
 .vscode/launch.template.jsonc | 21 +-
 backend/danswer/configs/chat_configs.py | 6 +-
 .../danswer/llm/answering/prune_and_merge.py | 101 +++++++-
 backend/danswer/search/models.py | 6 +-
 backend/danswer/search/pipeline.py | 29 +--
 .../llm/answering/test_prune_and_merge.py | 229 ++++++++++++++++++
 6 files changed, 358 insertions(+), 34 deletions(-)
 create mode 100644 backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py

diff --git a/.vscode/launch.template.jsonc b/.vscode/launch.template.jsonc
index 19bfc513b..a4be80fc1 100644
--- a/.vscode/launch.template.jsonc
+++ b/.vscode/launch.template.jsonc
@@ -83,6 +83,7 @@
       "cwd": "${workspaceFolder}/backend",
       "envFile": "${workspaceFolder}/.env",
       "env": {
+        "LOG_DANSWER_MODEL_INTERACTIONS": "True",
         "LOG_LEVEL": "DEBUG",
         "PYTHONUNBUFFERED": "1",
         "PYTHONPATH": "."
@@ -105,6 +106,24 @@
         "PYTHONUNBUFFERED": "1",
         "PYTHONPATH": "."
       }
+    },
+    {
+      "name": "Pytest",
+      "type": "python",
+      "request": "launch",
+      "module": "pytest",
+      "cwd": "${workspaceFolder}/backend",
+      "envFile": "${workspaceFolder}/.env",
+      "env": {
+        "LOG_LEVEL": "DEBUG",
+        "PYTHONUNBUFFERED": "1",
+        "PYTHONPATH": "."
+      },
+      "args": [
+        "-v"
+        // Specify a specific module/test to run, or provide nothing to run all tests
+        //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
+      ]
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py
index 19b7bb756..3595480e2 100644
--- a/backend/danswer/configs/chat_configs.py
+++ b/backend/danswer/configs/chat_configs.py
@@ -28,6 +28,10 @@ BASE_RECENCY_DECAY = 0.5
 FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
 # Currently this next one is not configurable via env
 DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
+# For the highest matching base size chunk, how many chunks above and below do we pull in by default
+# Note this is not in any of the deployment configs yet
+CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
+CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
 # Whether the LLM should evaluate all of the document chunks passed in for usefulness
 # in relation to the user query
 DISABLE_LLM_CHUNK_FILTER = (
@@ -43,8 +47,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
 # 1 edit per 20 characters, currently unused due to fuzzy match being too slow
 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
 QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
-# Include additional document/chunk metadata in prompt to GenerativeAI
-INCLUDE_METADATA = False
 # Keyword Search Drop Stopwords
 # If user has changed the default model, would most likely be to use a multilingual
 # model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/llm/answering/prune_and_merge.py
index 8747cdde7..3fee5266d 100644
--- a/backend/danswer/llm/answering/prune_and_merge.py
+++ b/backend/danswer/llm/answering/prune_and_merge.py
@@ -1,7 +1,10 @@
 import json
+from collections import defaultdict
 from copy import deepcopy
 from typing import TypeVar
 
+from pydantic import BaseModel
+
 from danswer.chat.models import (
     LlmDoc,
 )
@@ -31,6 +34,36 @@ class PruningError(Exception):
     pass
 
 
+class ChunkRange(BaseModel):
+    chunks: list[InferenceChunk]
+    start: int
+    end: int
+
+
+def 
merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]: + """ + This acts on a single document to merge the overlapping ranges of chunks + Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals + + NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the + document index, this does not merge the actual contents so it should not be used to actually + merge chunks post retrieval. + """ + sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start) + + combined_ranges: list[ChunkRange] = [] + + for new_chunk_range in sorted_ranges: + if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1: + combined_ranges.append(new_chunk_range) + else: + current_range = combined_ranges[-1] + current_range.end = max(current_range.end, new_chunk_range.end) + current_range.chunks.extend(new_chunk_range.chunks) + + return combined_ranges + + def _compute_limit( prompt_config: PromptConfig, llm_config: LLMConfig, @@ -219,6 +252,7 @@ def prune_sections( question: str, document_pruning_config: DocumentPruningConfig, ) -> list[InferenceSection]: + # Assumes the sections are score ordered with highest first if section_relevance_list is not None: assert len(sections) == len(section_relevance_list) @@ -241,6 +275,67 @@ def prune_sections( ) +def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection: + # Assuming there are no duplicates by this point + sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id) + + center_chunk = max( + chunks, key=lambda x: x.score if x.score is not None else float("-inf") + ) + + merged_content = [] + for i, chunk in enumerate(sorted_chunks): + if i > 0: + prev_chunk_id = sorted_chunks[i - 1].chunk_id + if chunk.chunk_id == prev_chunk_id + 1: + merged_content.append("\n") + else: + merged_content.append("\n\n...\n\n") + merged_content.append(chunk.content) + + combined_content = "".join(merged_content) + + return InferenceSection( + center_chunk=center_chunk, + chunks=sorted_chunks, + combined_content=combined_content, + ) + + +def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]: + docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict) + doc_order: dict[str, int] = {} + for index, section in enumerate(sections): + if section.center_chunk.document_id not in doc_order: + doc_order[section.center_chunk.document_id] = index + for chunk in [section.center_chunk] + section.chunks: + chunks_map = docs_map[section.center_chunk.document_id] + existing_chunk = chunks_map.get(chunk.chunk_id) + if ( + existing_chunk is None + or existing_chunk.score is None + or chunk.score is not None + and chunk.score > existing_chunk.score + ): + chunks_map[chunk.chunk_id] = chunk + + new_sections = [] + for section_chunks in docs_map.values(): + new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values()))) + + # Sort by highest score, then by original document order + # It is now 1 large section per doc, the center chunk being the one with the highest score + new_sections.sort( + key=lambda x: ( + x.center_chunk.score if x.center_chunk.score is not None else 0, + -1 * doc_order[x.center_chunk.document_id], + ), + reverse=True, + ) + + return new_sections + + def prune_and_merge_sections( sections: list[InferenceSection], section_relevance_list: list[bool] | None, @@ -249,6 +344,7 @@ def prune_and_merge_sections( question: str, document_pruning_config: DocumentPruningConfig, ) -> list[InferenceSection]: + # Assumes the sections are score 
ordered with highest first remaining_sections = prune_sections( sections=sections, section_relevance_list=section_relevance_list, @@ -257,6 +353,7 @@ def prune_and_merge_sections( question=question, document_pruning_config=document_pruning_config, ) - # TODO add the actual section combination logic - return remaining_sections + merged_sections = _merge_sections(sections=remaining_sections) + + return merged_sections diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index ada16ef06..6e16de2c7 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -4,6 +4,8 @@ from typing import Any from pydantic import BaseModel from pydantic import validator +from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE +from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import NUM_RERANKED_RESULTS @@ -47,8 +49,8 @@ class ChunkMetric(BaseModel): class ChunkContext(BaseModel): # Additional surrounding context options, if full doc, then chunks are deduped # If surrounding context overlap, it is combined into one - chunks_above: int = 0 - chunks_below: int = 0 + chunks_above: int = CONTEXT_CHUNKS_ABOVE + chunks_below: int = CONTEXT_CHUNKS_BELOW full_doc: bool = False @validator("chunks_above", "chunks_below", pre=True, each_item=False) diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py index 10381acc7..2d990c15a 100644 --- a/backend/danswer/search/pipeline.py +++ b/backend/danswer/search/pipeline.py @@ -3,13 +3,14 @@ from collections.abc import Callable from collections.abc import Iterator from typing import cast -from pydantic import BaseModel from sqlalchemy.orm import Session from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.models import User from danswer.document_index.factory import get_default_document_index +from danswer.llm.answering.prune_and_merge import ChunkRange +from danswer.llm.answering.prune_and_merge import merge_chunk_intervals from danswer.llm.interfaces import LLM from danswer.search.enums import QueryFlow from danswer.search.enums import SearchType @@ -30,32 +31,6 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle logger = setup_logger() -class ChunkRange(BaseModel): - chunks: list[InferenceChunk] - start: int - end: int - combined_content: str | None = None - - -def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]: - """This acts on a single document to merge the overlapping ranges of sections - Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals - """ - sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start) - - combined_ranges: list[ChunkRange] = [] - - for new_chunk_range in sorted_ranges: - if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start: - combined_ranges.append(new_chunk_range) - else: - current_range = combined_ranges[-1] - current_range.end = max(current_range.end, new_chunk_range.end) - current_range.chunks.extend(new_chunk_range.chunks) - - return combined_ranges - - class SearchPipeline: def __init__( self, diff --git a/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py b/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py new file mode 100644 index 
000000000..1782f3edb --- /dev/null +++ b/backend/tests/unit/danswer/llm/answering/test_prune_and_merge.py @@ -0,0 +1,229 @@ +import pytest + +from danswer.configs.constants import DocumentSource +from danswer.llm.answering.prune_and_merge import _merge_sections +from danswer.search.models import InferenceChunk +from danswer.search.models import InferenceSection + + +# This large test accounts for all of the following: +# 1. Merging of adjacent sections +# 2. Merging of non-adjacent sections +# 3. Merging of sections where there are multiple documents +# 4. Verifying the contents of merged sections +# 5. Verifying the order/score of the merged sections + + +def create_inference_chunk( + document_id: str, chunk_id: int, content: str, score: float | None +) -> InferenceChunk: + """ + Create an InferenceChunk with hardcoded values for testing purposes. + """ + return InferenceChunk( + chunk_id=chunk_id, + document_id=document_id, + semantic_identifier=f"{document_id}_{chunk_id}", + blurb=f"{document_id}_{chunk_id}", + content=content, + source_links={0: "fake_link"}, + section_continuation=False, + source_type=DocumentSource.WEB, + boost=0, + recency_bias=1.0, + score=score, + hidden=False, + metadata={}, + match_highlights=[], + updated_at=None, + ) + + +# Document 1, top connected sections +DOC_1_FILLER_1 = create_inference_chunk("doc1", 2, "Content 2", 1.0) +DOC_1_FILLER_2 = create_inference_chunk("doc1", 3, "Content 3", 2.0) +DOC_1_TOP_CHUNK = create_inference_chunk("doc1", 4, "Content 4", None) +DOC_1_MID_CHUNK = create_inference_chunk("doc1", 5, "Content 5", 4.0) +DOC_1_FILLER_3 = create_inference_chunk("doc1", 6, "Content 6", 5.0) +DOC_1_FILLER_4 = create_inference_chunk("doc1", 7, "Content 7", 6.0) +# This chunk below has the top score for testing +DOC_1_BOTTOM_CHUNK = create_inference_chunk("doc1", 8, "Content 8", 70.0) +DOC_1_FILLER_5 = create_inference_chunk("doc1", 9, "Content 9", None) +DOC_1_FILLER_6 = create_inference_chunk("doc1", 10, "Content 10", 9.0) +# Document 1, separate section +DOC_1_FILLER_7 = create_inference_chunk("doc1", 13, "Content 13", 10.0) +DOC_1_FILLER_8 = create_inference_chunk("doc1", 14, "Content 14", 11.0) +DOC_1_DISCONNECTED = create_inference_chunk("doc1", 15, "Content 15", 12.0) +DOC_1_FILLER_9 = create_inference_chunk("doc1", 16, "Content 16", 13.0) +DOC_1_FILLER_10 = create_inference_chunk("doc1", 17, "Content 17", 14.0) +# Document 2 +DOC_2_FILLER_1 = create_inference_chunk("doc2", 1, "Doc 2 Content 1", 15.0) +DOC_2_FILLER_2 = create_inference_chunk("doc2", 2, "Doc 2 Content 2", 16.0) +# This chunk below has top score for testing +DOC_2_TOP_CHUNK = create_inference_chunk("doc2", 3, "Doc 2 Content 3", 170.0) +DOC_2_FILLER_3 = create_inference_chunk("doc2", 4, "Doc 2 Content 4", 18.0) +DOC_2_BOTTOM_CHUNK = create_inference_chunk("doc2", 5, "Doc 2 Content 5", 19.0) +DOC_2_FILLER_4 = create_inference_chunk("doc2", 6, "Doc 2 Content 6", 20.0) +DOC_2_FILLER_5 = create_inference_chunk("doc2", 7, "Doc 2 Content 7", 21.0) + + +# Doc 2 has the highest score so it comes first +EXPECTED_CONTENT_1 = """ +Doc 2 Content 1 +Doc 2 Content 2 +Doc 2 Content 3 +Doc 2 Content 4 +Doc 2 Content 5 +Doc 2 Content 6 +Doc 2 Content 7 +""".strip() + + +EXPECTED_CONTENT_2 = """ +Content 2 +Content 3 +Content 4 +Content 5 +Content 6 +Content 7 +Content 8 +Content 9 +Content 10 + +... 
+ +Content 13 +Content 14 +Content 15 +Content 16 +Content 17 +""".strip() + + +@pytest.mark.parametrize( + "sections,expected_contents,expected_center_chunks", + [ + ( + # Sections + [ + # Document 1, top/middle/bot connected + disconnected section + InferenceSection( + center_chunk=DOC_1_TOP_CHUNK, + chunks=[ + DOC_1_FILLER_1, + DOC_1_FILLER_2, + DOC_1_TOP_CHUNK, + DOC_1_MID_CHUNK, + DOC_1_FILLER_3, + ], + combined_content="N/A", # Not used + ), + InferenceSection( + center_chunk=DOC_1_MID_CHUNK, + chunks=[ + DOC_1_FILLER_2, + DOC_1_TOP_CHUNK, + DOC_1_MID_CHUNK, + DOC_1_FILLER_3, + DOC_1_FILLER_4, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_1_BOTTOM_CHUNK, + chunks=[ + DOC_1_FILLER_3, + DOC_1_FILLER_4, + DOC_1_BOTTOM_CHUNK, + DOC_1_FILLER_5, + DOC_1_FILLER_6, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_1_DISCONNECTED, + chunks=[ + DOC_1_FILLER_7, + DOC_1_FILLER_8, + DOC_1_DISCONNECTED, + DOC_1_FILLER_9, + DOC_1_FILLER_10, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_2_TOP_CHUNK, + chunks=[ + DOC_2_FILLER_1, + DOC_2_FILLER_2, + DOC_2_TOP_CHUNK, + DOC_2_FILLER_3, + DOC_2_BOTTOM_CHUNK, + ], + combined_content="N/A", + ), + InferenceSection( + center_chunk=DOC_2_BOTTOM_CHUNK, + chunks=[ + DOC_2_TOP_CHUNK, + DOC_2_FILLER_3, + DOC_2_BOTTOM_CHUNK, + DOC_2_FILLER_4, + DOC_2_FILLER_5, + ], + combined_content="N/A", + ), + ], + # Expected Content + [EXPECTED_CONTENT_1, EXPECTED_CONTENT_2], + # Expected Center Chunks (highest scores) + [DOC_2_TOP_CHUNK, DOC_1_BOTTOM_CHUNK], + ), + ], +) +def test_merge_sections( + sections: list[InferenceSection], + expected_contents: list[str], + expected_center_chunks: list[InferenceChunk], +) -> None: + sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True) + merged_sections = _merge_sections(sections) + assert merged_sections[0].combined_content == expected_contents[0] + assert merged_sections[1].combined_content == expected_contents[1] + assert merged_sections[0].center_chunk == expected_center_chunks[0] + assert merged_sections[1].center_chunk == expected_center_chunks[1] + + +@pytest.mark.parametrize( + "sections,expected_content,expected_center_chunk", + [ + ( + # Sections + [ + InferenceSection( + center_chunk=DOC_1_TOP_CHUNK, + chunks=[DOC_1_TOP_CHUNK], + combined_content="N/A", # Not used + ), + InferenceSection( + center_chunk=DOC_1_MID_CHUNK, + chunks=[DOC_1_MID_CHUNK], + combined_content="N/A", + ), + ], + # Expected Content + "Content 4\nContent 5", + # Expected Center Chunks (highest scores) + DOC_1_MID_CHUNK, + ), + ], +) +def test_merge_minimal_sections( + sections: list[InferenceSection], + expected_content: str, + expected_center_chunk: InferenceChunk, +) -> None: + sections.sort(key=lambda section: section.center_chunk.score or 0, reverse=True) + merged_sections = _merge_sections(sections) + assert merged_sections[0].combined_content == expected_content + assert merged_sections[0].center_chunk == expected_center_chunk
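Illustrative sketch, not part of the patch: a minimal end-to-end example of what _merge_sections produces for one document, in the same style as the tests above. The document id "docA", the contents, and the scores are made-up values, and the import path for the create_inference_chunk helper (defined in the new test module) is an assumption. Consecutive chunks are joined with a newline, a gap in chunk_ids is marked with "...", and the highest-scoring chunk becomes the merged section's center chunk.

from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.search.models import InferenceSection

# Hypothetical import path for the helper shown in the test file above; adjust to the actual test layout.
from tests.unit.danswer.llm.answering.test_prune_and_merge import create_inference_chunk

chunk_1 = create_inference_chunk("docA", 1, "A1", 1.0)
chunk_2 = create_inference_chunk("docA", 2, "A2", 9.0)  # best score -> becomes the center chunk
chunk_5 = create_inference_chunk("docA", 5, "A5", 3.0)  # not adjacent to chunk 2, so a "..." gap is inserted

sections = [
    InferenceSection(center_chunk=chunk_2, chunks=[chunk_1, chunk_2], combined_content="N/A"),
    InferenceSection(center_chunk=chunk_5, chunks=[chunk_5], combined_content="N/A"),
]

merged = _merge_sections(sections)
assert len(merged) == 1  # one merged section per document
assert merged[0].center_chunk == chunk_2
assert merged[0].combined_content == "A1\nA2\n\n...\n\nA5"

Chunks that appear in more than one input section are deduplicated by chunk_id, keeping the copy with the higher score, which is why overlapping sections collapse cleanly. The new Pytest launch configuration (or "pytest -v tests/unit/danswer/llm/answering/test_prune_and_merge.py" from the backend directory) runs the real tests for this behavior.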
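Illustrative sketch, not part of the patch: how the relocated merge_chunk_intervals helper is expected to combine chunk ranges for a single document before the surrounding chunks are fetched from the document index. Ranges that overlap or touch (the next range starting exactly one chunk_id after the previous one ends) are folded together; ranges separated by a gap stay separate. The start/end values are arbitrary and the empty chunks lists are placeholders for brevity.

from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals

ranges = [
    ChunkRange(chunks=[], start=4, end=6),
    ChunkRange(chunks=[], start=7, end=9),    # touches 4-6, so the two fold into a single 4-9 range
    ChunkRange(chunks=[], start=15, end=17),  # separated by a gap, stays its own range
]

merged = merge_chunk_intervals(ranges)
assert [(r.start, r.end) for r in merged] == [(4, 9), (15, 17)]

As the docstring notes, this only decides which chunk_ids to retrieve; the actual content merging happens later in _merge_doc_chunks / _merge_sections.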
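Illustrative sketch, not part of the patch: the effect of the new CONTEXT_CHUNKS_ABOVE / CONTEXT_CHUNKS_BELOW settings. ChunkContext now defaults chunks_above and chunks_below to these env-driven values instead of 0, so a deployment can widen the context pulled in around each matching chunk without changing request payloads. The values below are arbitrary examples, and note that the env vars are read once when danswer.configs.chat_configs is imported, so they have to be set before that import happens (normally in the deployment environment rather than in code).

import os

# Hypothetical values purely for illustration; any non-negative integers work.
os.environ["CONTEXT_CHUNKS_ABOVE"] = "1"
os.environ["CONTEXT_CHUNKS_BELOW"] = "2"

from danswer.search.models import ChunkContext  # noqa: E402  (imported after setting the env vars on purpose)

ctx = ChunkContext()
assert ctx.chunks_above == 1
assert ctx.chunks_below == 2
assert ctx.full_doc is False

# Per-request overrides keep working exactly as before.
ctx_override = ChunkContext(chunks_above=0, chunks_below=0, full_doc=False)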