# danswer/backend/onyx/chat/prune_and_merge.py
import json
from collections import defaultdict
from copy import deepcopy
from typing import TypeVar

from pydantic import BaseModel

from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLMConfig
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.prompts.prompt_utils import build_doc_context_str
from onyx.tools.tool_implementations.search.search_utils import section_to_dict
from onyx.utils.logger import setup_logger


logger = setup_logger()

T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection)

_METADATA_TOKEN_ESTIMATE = 75

# Title and additional tokens as part of the tool message json.
# This is only used to log a warning, so we can be more forgiving with the buffer.
_OVERCOUNT_ESTIMATE = 256


class PruningError(Exception):
    pass


class ChunkRange(BaseModel):
    chunks: list[InferenceChunk]
    start: int
    end: int
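
# Note: a ChunkRange describes a span of chunk positions within a single document
# ([start, end], treated as inclusive by the merging logic below), together with the
# chunks already retrieved for that span.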


def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
    """
    This acts on a single document to merge the overlapping ranges of chunks.
    Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals

    NOTE: this is used to merge chunk ranges for retrieving the right chunk_ids against the
    document index; it does not merge the actual contents, so it should not be used to actually
    merge chunks post retrieval.
    """
    sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)

    combined_ranges: list[ChunkRange] = []

    for new_chunk_range in sorted_ranges:
        if not combined_ranges or combined_ranges[-1].end < new_chunk_range.start - 1:
            combined_ranges.append(new_chunk_range)
        else:
            current_range = combined_ranges[-1]
            current_range.end = max(current_range.end, new_chunk_range.end)
            current_range.chunks.extend(new_chunk_range.chunks)

    return combined_ranges
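
# Example for merge_chunk_intervals: within one document, ranges [0, 3] and [2, 5] merge into
# a single ChunkRange covering [0, 5]; adjacent ranges like [0, 3] and [4, 6] also merge, since
# a new range is only started when end < start - 1, while [0, 3] and [5, 7] stay separate.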


def _compute_limit(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    max_chunks: int | None,
    max_window_percentage: float | None,
    max_tokens: int | None,
    tool_token_count: int,
) -> int:
    llm_max_document_tokens = compute_max_document_tokens(
        prompt_config=prompt_config,
        llm_config=llm_config,
        tool_token_count=tool_token_count,
        actual_user_input=question,
    )

    window_percentage_based_limit = (
        max_window_percentage * llm_max_document_tokens
        if max_window_percentage
        else None
    )
    chunk_count_based_limit = (
        max_chunks * DOC_EMBEDDING_CONTEXT_SIZE if max_chunks else None
    )

    limit_options = [
        lim
        for lim in [
            window_percentage_based_limit,
            chunk_count_based_limit,
            max_tokens,
            llm_max_document_tokens,
        ]
        if lim
    ]
    return int(min(limit_options))
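
# Example for _compute_limit (hypothetical numbers, assuming DOC_EMBEDDING_CONTEXT_SIZE is 512):
# with an LLM document budget of 8000 tokens, max_window_percentage=0.5, max_chunks=10, and
# max_tokens unset, the candidate limits are 4000, 5120, and 8000, so the returned limit is
# 4000. The tightest applicable constraint always wins.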


def reorder_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
) -> list[InferenceSection]:
    if section_relevance_list is None:
        return sections

    reordered_sections: list[InferenceSection] = []
    if section_relevance_list is not None:
        for selection_target in [True, False]:
            for section, is_relevant in zip(sections, section_relevance_list):
                if is_relevant == selection_target:
                    reordered_sections.append(section)
    return reordered_sections
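
# Example for reorder_sections: sections [A, B, C] with relevance [False, True, True] come back
# as [B, C, A]; relevant sections are moved to the front and the relative order within each
# group is preserved.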


def _remove_sections_to_ignore(
    sections: list[InferenceSection],
) -> list[InferenceSection]:
    return [
        section
        for section in sections
        if not section.center_chunk.metadata.get(IGNORE_FOR_QA)
    ]
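
# Note: IGNORE_FOR_QA (imported from onyx.configs.constants) is the metadata key checked above;
# any Section whose center chunk has it set is excluded from the QA context.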


def _apply_pruning(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    token_limit: int,
    is_manually_selected_docs: bool,
    use_sections: bool,
    using_tool_message: bool,
    llm_config: LLMConfig,
) -> list[InferenceSection]:
    llm_tokenizer = get_tokenizer(
        provider_type=llm_config.model_provider,
        model_name=llm_config.model_name,
    )
    sections = deepcopy(sections)  # don't modify in place

    # re-order docs with all the "relevant" docs at the front
    sections = reorder_sections(
        sections=sections, section_relevance_list=section_relevance_list
    )
    # remove docs that are explicitly marked as not for QA
    sections = _remove_sections_to_ignore(sections=sections)

    final_section_ind = None
    total_tokens = 0
    for ind, section in enumerate(sections):
        section_str = (
            # If using the tool message, this is a slight overestimate: the extra JSON text around
            # the section is counted towards the token total. However, once the Sections are merged,
            # the overlapping JSON parts are not counted multiple times the way they are in this
            # pruning step.
            json.dumps(section_to_dict(section, ind))
            if using_tool_message
            else build_doc_context_str(
                semantic_identifier=section.center_chunk.semantic_identifier,
                source_type=section.center_chunk.source_type,
                content=section.combined_content,
                metadata_dict=section.center_chunk.metadata,
                updated_at=section.center_chunk.updated_at,
                ind=ind,
            )
        )

        section_token_count = len(llm_tokenizer.encode(section_str))
        # If not using Sections (specifically, each Section maps exactly to its one center chunk),
        # truncate chunks that are way too long. This can happen if the embedding model tokenizer
        # differs from the LLM tokenizer.
        if (
            not is_manually_selected_docs
            and not use_sections
            and section_token_count
            > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
        ):
            if (
                section_token_count
                > DOC_EMBEDDING_CONTEXT_SIZE
                + _METADATA_TOKEN_ESTIMATE
                + _OVERCOUNT_ESTIMATE
            ):
                # If the section is only a little over, the overage is likely just the additional
                # tool message tokens and not worth recording; either way the content is trimmed
                # below just in case.
                logger.warning(
                    "Found more tokens in Section than expected, "
                    "likely mismatch between embedding and LLM tokenizers. Trimming content..."
                )

            section.combined_content = tokenizer_trim_content(
                content=section.combined_content,
                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                tokenizer=llm_tokenizer,
            )
            section_token_count = DOC_EMBEDDING_CONTEXT_SIZE

        total_tokens += section_token_count
        if total_tokens > token_limit:
            final_section_ind = ind
            break

    if final_section_ind is not None:
        if is_manually_selected_docs or use_sections:
            if final_section_ind != len(sections) - 1:
                # If using Sections, the final section could be more than we need; in this case we
                # are willing to truncate the final section to fit the specified context window
                sections = sections[: final_section_ind + 1]

                if is_manually_selected_docs:
                    # For the document selection flow, only the final document/section is allowed to
                    # be truncated. If more than that would need to be thrown away, some documents
                    # are dropped entirely, which should be reported to the user as an error.
                    raise PruningError(
                        "LLM context window exceeded. Please de-select some documents or shorten your query."
                    )

            amount_to_truncate = total_tokens - token_limit
            # NOTE: need to recalculate the length here, since the previous calculation included
            # overhead from JSON-fying the doc / the metadata
            final_doc_content_length = len(
                llm_tokenizer.encode(sections[final_section_ind].combined_content)
            ) - amount_to_truncate
            # This could occur if we only have space for the title / metadata.
            # Not ideal, but it's the most reasonable thing to do.
            # NOTE: the frontend prevents documents from being selected if
            # less than 75 tokens are available to try and avoid this situation
            # from occurring in the first place
            if final_doc_content_length <= 0:
                logger.error(
                    f"Final section ({sections[final_section_ind].center_chunk.semantic_identifier}) content "
                    "length is 0 or less. Removing this section from the final prompt."
                )
                sections.pop()
            else:
                sections[final_section_ind].combined_content = tokenizer_trim_content(
                    content=sections[final_section_ind].combined_content,
                    desired_length=final_doc_content_length,
                    tokenizer=llm_tokenizer,
                )
        else:
            # For search on the chunk level (each Section is just a chunk), don't truncate the final
            # Chunk/Section unless it's the only one. If it's not the only one, we can throw it away;
            # if it's the only one, we have to truncate.
            if final_section_ind != 0:
                sections = sections[:final_section_ind]
            else:
                sections[0].combined_content = tokenizer_trim_content(
                    content=sections[0].combined_content,
                    desired_length=token_limit - _METADATA_TOKEN_ESTIMATE,
                    tokenizer=llm_tokenizer,
                )
                sections = [sections[0]]

    return sections
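
# Rough behavior of _apply_pruning once the token limit is crossed: in the manually-selected /
# Sections flows, everything past the overflowing Section is cut and that Section is truncated
# to fit (for manually selected docs, a PruningError is raised instead whenever whole documents
# would have to be dropped); in the chunk-level flow, the overflowing Section and everything
# after it are dropped, unless it is the very first one, in which case it alone is truncated.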


def prune_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    contextual_pruning_config: ContextualPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    if section_relevance_list is not None:
        assert len(sections) == len(section_relevance_list)

    actual_num_chunks = (
        contextual_pruning_config.max_chunks
        * contextual_pruning_config.num_chunk_multiple
        if contextual_pruning_config.max_chunks
        else None
    )

    token_limit = _compute_limit(
        prompt_config=prompt_config,
        llm_config=llm_config,
        question=question,
        max_chunks=actual_num_chunks,
        max_window_percentage=contextual_pruning_config.max_window_percentage,
        max_tokens=contextual_pruning_config.max_tokens,
        tool_token_count=contextual_pruning_config.tool_num_tokens,
    )
    return _apply_pruning(
        sections=sections,
        section_relevance_list=section_relevance_list,
        token_limit=token_limit,
        is_manually_selected_docs=contextual_pruning_config.is_manually_selected_docs,
        use_sections=contextual_pruning_config.use_sections,  # Now default True
        using_tool_message=contextual_pruning_config.using_tool_message,
        llm_config=llm_config,
    )
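
# Note for prune_sections: it scales max_chunks by num_chunk_multiple, converts the
# ContextualPruningConfig into a concrete token limit via _compute_limit, and delegates the
# actual trimming to _apply_pruning.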


def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
    # Assuming there are no duplicates by this point
    sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

    center_chunk = max(
        chunks, key=lambda x: x.score if x.score is not None else float("-inf")
    )

    merged_content = []
    for i, chunk in enumerate(sorted_chunks):
        if i > 0:
            prev_chunk_id = sorted_chunks[i - 1].chunk_id
            if chunk.chunk_id == prev_chunk_id + 1:
                merged_content.append("\n")
            else:
                merged_content.append("\n\n...\n\n")
        merged_content.append(chunk.content)

    combined_content = "".join(merged_content)

    return InferenceSection(
        center_chunk=center_chunk,
        chunks=sorted_chunks,
        combined_content=combined_content,
    )
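
# Example for _merge_doc_chunks: chunks with ids 3, 4, and 7 from the same document are joined
# as "<chunk 3>\n<chunk 4>\n\n...\n\n<chunk 7>"; consecutive ids get a plain newline while a gap
# is marked with an ellipsis, and the highest-scoring chunk becomes the center_chunk.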


def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
    docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
    doc_order: dict[str, int] = {}
    for index, section in enumerate(sections):
        if section.center_chunk.document_id not in doc_order:
            doc_order[section.center_chunk.document_id] = index
        for chunk in [section.center_chunk] + section.chunks:
            chunks_map = docs_map[section.center_chunk.document_id]
            existing_chunk = chunks_map.get(chunk.chunk_id)
            if (
                existing_chunk is None
                or existing_chunk.score is None
                or chunk.score is not None
                and chunk.score > existing_chunk.score
            ):
                chunks_map[chunk.chunk_id] = chunk

    new_sections = []
    for section_chunks in docs_map.values():
        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))

    # Sort by highest score, then by original document order
    # It is now 1 large section per doc, the center chunk being the one with the highest score
    new_sections.sort(
        key=lambda x: (
            x.center_chunk.score if x.center_chunk.score is not None else 0,
            -1 * doc_order[x.center_chunk.document_id],
        ),
        reverse=True,
    )

    return new_sections
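
# Example for _merge_sections: if two retrieved Sections come from the same document, their
# chunks are pooled into one per-document Section (keeping the higher-scoring copy of any chunk
# seen twice). The merged Sections are then ordered by center_chunk score, falling back to the
# order in which each document first appeared.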


def prune_and_merge_sections(
    sections: list[InferenceSection],
    section_relevance_list: list[bool] | None,
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    question: str,
    contextual_pruning_config: ContextualPruningConfig,
) -> list[InferenceSection]:
    # Assumes the sections are score ordered with highest first
    remaining_sections = prune_sections(
        sections=sections,
        section_relevance_list=section_relevance_list,
        prompt_config=prompt_config,
        llm_config=llm_config,
        question=question,
        contextual_pruning_config=contextual_pruning_config,
    )

    merged_sections = _merge_sections(sections=remaining_sections)

    return merged_sections
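
# prune_and_merge_sections is the top-level entry point of this module: it first prunes the
# score-ordered Sections down to the computed token budget, then merges what remains into a
# single combined Section per document.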