Provide Additional Context for Chunk Options in APIs (#1330)

Yuhong Sun 2024-04-14 18:32:22 -07:00 committed by GitHub
parent b9b1e22fac
commit a17060af5a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
31 changed files with 421 additions and 147 deletions

View File

@ -7,16 +7,19 @@ from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
from danswer.indexing.models import InferenceChunk
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def llm_doc_from_inference_chunk(inf_chunk: InferenceChunk) -> LlmDoc:
def llm_doc_from_inference_section(inf_chunk: InferenceSection) -> LlmDoc:
return LlmDoc(
document_id=inf_chunk.document_id,
content=inf_chunk.content,
# This uses the combined content of all the chunks in the section
# With default settings, this is the same as just the content of the base chunk
content=inf_chunk.combined_content,
blurb=inf_chunk.blurb,
semantic_identifier=inf_chunk.semantic_identifier,
source_type=inf_chunk.source_type,

View File

@ -6,7 +6,7 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
@ -44,7 +44,7 @@ from danswer.search.models import OptionalSearchSetting
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.search.retrieval.search_runner import inference_documents_from_ids
from danswer.search.utils import chunks_to_search_docs
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@ -216,6 +216,7 @@ def stream_chat_message_objects(
)
rephrased_query = None
llm_relevance_list = None
if reference_doc_ids:
identifier_tuples = get_doc_query_identifiers_from_model(
search_doc_ids=reference_doc_ids,
@ -263,13 +264,16 @@ def stream_chat_message_objects(
persona=persona,
offset=retrieval_options.offset if retrieval_options else None,
limit=retrieval_options.limit if retrieval_options else None,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
),
user=user,
db_session=db_session,
)
top_chunks = search_pipeline.reranked_docs
top_docs = chunks_to_search_docs(top_chunks)
top_sections = search_pipeline.reranked_sections
top_docs = chunks_or_sections_to_search_docs(top_sections)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
@ -294,7 +298,7 @@ def stream_chat_message_objects(
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
relevant_chunk_indices=search_pipeline.relevant_chunk_indices
)
yield llm_relevance_filtering_response
@ -305,9 +309,13 @@ def stream_chat_message_objects(
else default_num_chunks
),
max_window_percentage=max_document_percentage,
use_sections=search_pipeline.ran_merge_chunk,
)
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in top_chunks]
llm_docs = [
llm_doc_from_inference_section(section) for section in top_sections
]
llm_relevance_list = search_pipeline.section_relevance_list
else:
llm_docs = []
@ -369,6 +377,7 @@ def stream_chat_message_objects(
persona,
llm_override=(new_msg_req.llm_override or chat_session.llm_override),
),
doc_relevance_list=llm_relevance_list,
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
],

View File

@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.indexing.models import IndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.search.models import InferenceChunk
DEFAULT_BATCH_SIZE = 30

View File

@ -5,8 +5,8 @@ from typing import Any
from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@dataclass(frozen=True)
@ -183,7 +183,8 @@ class IdRetrievalCapable(abc.ABC):
def id_based_retrieval(
self,
document_id: str,
chunk_ind: int | None,
min_chunk_ind: int | None,
max_chunk_ind: int | None,
filters: IndexFilters,
) -> list[InferenceChunk]:
"""
@ -196,7 +197,8 @@ class IdRetrievalCapable(abc.ABC):
Parameters:
- document_id: document id for which to retrieve the chunk(s)
- chunk_ind: chunk index to return, if None, return all of the chunks in order
- min_chunk_ind: if None then fetch from the start of doc
- max_chunk_ind: if None then fetch to the end of doc
- filters: standard filters object, in this case only the access filter is applied as a
permission check
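A minimal sketch of how a caller uses the new range-based signature (the document_index here stands for any DocumentIndex implementation; the document id and chunk indices are placeholders, see the real call sites further down in this commit):

from danswer.search.models import IndexFilters

def fetch_chunk_ranges(document_index, doc_id: str):
    # Only the access filter matters here; no ACL restriction for illustration
    filters = IndexFilters(access_control_list=None)

    # All chunks of the document, in order
    whole_doc = document_index.id_based_retrieval(
        document_id=doc_id, min_chunk_ind=None, max_chunk_ind=None, filters=filters
    )

    # A window around chunk 5 with one chunk above and two below
    window = document_index.id_based_retrieval(
        document_id=doc_id, min_chunk_ind=4, max_chunk_ind=7, filters=filters
    )

    # A single chunk: pass the same index for both bounds
    single = document_index.id_based_retrieval(
        document_id=doc_id, min_chunk_ind=5, max_chunk_ind=5, filters=filters
    )
    return whole_doc, window, single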

View File

@ -62,8 +62,8 @@ from danswer.document_index.interfaces import DocumentInsertionRecord
from danswer.document_index.interfaces import UpdateRequest
from danswer.document_index.vespa.utils import remove_invalid_unicode_chars
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
@ -864,10 +864,11 @@ class VespaIndex(DocumentIndex):
def id_based_retrieval(
self,
document_id: str,
chunk_ind: int | None,
min_chunk_ind: int | None,
max_chunk_ind: int | None,
filters: IndexFilters,
) -> list[InferenceChunk]:
if chunk_ind is None:
if min_chunk_ind is None and max_chunk_ind is None:
vespa_chunk_ids = _get_vespa_chunk_ids_by_document_id(
document_id=document_id,
index_name=self.index_name,
@ -888,14 +889,22 @@ class VespaIndex(DocumentIndex):
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
return inference_chunks
else:
filters_str = _build_vespa_filters(filters=filters, include_hidden=True)
yql = (
VespaIndex.yql_base.format(index_name=self.index_name)
+ filters_str
+ f"({DOCUMENT_ID} contains '{document_id}' and {CHUNK_ID} contains '{chunk_ind}')"
)
return _query_vespa({"yql": yql})
filters_str = _build_vespa_filters(filters=filters, include_hidden=True)
yql = (
VespaIndex.yql_base.format(index_name=self.index_name)
+ filters_str
+ f"({DOCUMENT_ID} contains '{document_id}'"
)
if min_chunk_ind is not None:
yql += f" and {min_chunk_ind} <= {CHUNK_ID}"
if max_chunk_ind is not None:
yql += f" and {max_chunk_ind} >= {CHUNK_ID}"
yql = yql + ")"
inference_chunks = _query_vespa({"yql": yql})
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
return inference_chunks
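For illustration, assuming the DOCUMENT_ID and CHUNK_ID constants resolve to the Vespa field names "document_id" and "chunk_id", the range clause built above for min_chunk_ind=3 and max_chunk_ind=7 comes out as follows (a standalone sketch, not the actual VespaIndex code path):

DOCUMENT_ID = "document_id"  # assumed value of the constant
CHUNK_ID = "chunk_id"  # assumed value of the constant

def build_range_clause(document_id: str, min_chunk_ind: int | None, max_chunk_ind: int | None) -> str:
    clause = f"({DOCUMENT_ID} contains '{document_id}'"
    if min_chunk_ind is not None:
        clause += f" and {min_chunk_ind} <= {CHUNK_ID}"
    if max_chunk_ind is not None:
        clause += f" and {max_chunk_ind} >= {CHUNK_ID}"
    return clause + ")"

# build_range_clause("DOC123", 3, 7)
# -> "(document_id contains 'DOC123' and 3 <= chunk_id and 7 >= chunk_id)"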
def keyword_retrieval(
self,

View File

@ -149,7 +149,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict[title] = title_embedding
new_embedded_chunk = IndexChunk(
**{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
**chunk.dict(),
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],

View File

@ -1,12 +1,8 @@
from dataclasses import dataclass
from dataclasses import fields
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel
from danswer.access.models import DocumentAccess
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
@ -20,14 +16,12 @@ logger = setup_logger()
Embedding = list[float]
@dataclass
class ChunkEmbedding:
class ChunkEmbedding(BaseModel):
full_embedding: Embedding
mini_chunk_embeddings: list[Embedding]
@dataclass
class BaseChunk:
class BaseChunk(BaseModel):
chunk_id: int
blurb: str # The first sentence(s) of the first Section of the chunk
content: str
@ -37,7 +31,6 @@ class BaseChunk:
section_continuation: bool # True if this Chunk's start is not at the start of a Section
@dataclass
class DocAwareChunk(BaseChunk):
# During indexing flow, we have access to a complete "Document"
# During inference we only have access to the document id and do not reconstruct the Document
@ -50,13 +43,11 @@ class DocAwareChunk(BaseChunk):
)
@dataclass
class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding
title_embedding: Embedding | None
@dataclass
class DocMetadataAwareIndexChunk(IndexChunk):
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
the following:
@ -81,53 +72,15 @@ class DocMetadataAwareIndexChunk(IndexChunk):
document_sets: set[str],
boost: int,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.dict()
return cls(
**{
field.name: getattr(index_chunk, field.name)
for field in fields(index_chunk)
},
**index_chunk_data,
access=access,
document_sets=document_sets,
boost=boost,
)
@dataclass
class InferenceChunk(BaseChunk):
document_id: str
source_type: DocumentSource
semantic_identifier: str
boost: int
recency_bias: float
score: float | None
hidden: bool
metadata: dict[str, str | list[str]]
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
# when the doc was last updated
updated_at: datetime | None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@property
def unique_id(self) -> str:
return f"{self.document_id}__{self.chunk_id}"
def __repr__(self) -> str:
blurb_words = self.blurb.split()
short_blurb = ""
for word in blurb_words:
if not short_blurb:
short_blurb = word
continue
if len(short_blurb) > 25:
break
short_blurb += " " + word
return f"Inference Chunk: {self.document_id} - {short_blurb}..."
class EmbeddingModelDetail(BaseModel):
model_name: str
model_dim: int

View File

@ -31,15 +31,17 @@ from danswer.llm.utils import get_default_llm_tokenizer
def _get_stream_processor(
docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig
context_docs: list[LlmDoc],
search_order_docs: list[LlmDoc],
answer_style_configs: AnswerStyleConfig,
) -> StreamProcessor:
if answer_style_configs.citation_config:
return build_citation_processor(
context_docs=docs,
context_docs=context_docs, search_order_docs=search_order_docs
)
if answer_style_configs.quotes_config:
return build_quotes_processor(
context_docs=docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
context_docs=context_docs, is_json_prompt=not (QA_PROMPT_OVERRIDE == "weak")
)
raise RuntimeError("Not implemented yet")
@ -83,8 +85,6 @@ class Answer:
)
self.llm_tokenizer = get_default_llm_tokenizer()
self.process_stream_fn = _get_stream_processor(docs, answer_style_config)
self._final_prompt: list[BaseMessage] | None = None
self._pruned_docs: list[LlmDoc] | None = None
@ -152,8 +152,14 @@ class Answer:
yield from self._processed_stream
return
process_stream_fn = _get_stream_processor(
context_docs=self.pruned_docs,
search_order_docs=self.docs,
answer_style_configs=self.answer_style_config,
)
processed_stream = []
for processed_packet in self.process_stream_fn(self.raw_streamed_output):
for processed_packet in process_stream_fn(self.raw_streamed_output):
processed_stream.append(processed_packet)
yield processed_packet
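The split between the two document lists keeps citation numbering stable: the LLM only sees the pruned context_docs, but the id-to-rank map is built from the original search ordering, so a document keeps its citation number even when earlier documents were pruned away. A rough standalone sketch (assuming the rank map is simply the 1-based position of each document id in search order):

search_order_docs = ["doc_a", "doc_b", "doc_c"]  # order the documents were retrieved in
context_docs = ["doc_a", "doc_c"]  # pruning dropped doc_b from the LLM context

doc_id_to_rank_map = {doc_id: ind + 1 for ind, doc_id in enumerate(search_order_docs)}
# doc_c still maps to citation [3] even though only two documents reach the LLM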

View File

@ -6,7 +6,6 @@ from danswer.chat.models import (
)
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PromptConfig
@ -14,6 +13,7 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_document_
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
@ -87,6 +87,7 @@ def _apply_pruning(
doc_relevance_list: list[bool] | None,
token_limit: int,
is_manually_selected_docs: bool,
use_sections: bool,
) -> list[LlmDoc]:
llm_tokenizer = get_default_llm_tokenizer()
docs = deepcopy(docs) # don't modify in place
@ -117,6 +118,7 @@ def _apply_pruning(
# than the LLM tokenizer
if (
not is_manually_selected_docs
and not use_sections
and doc_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
):
logger.warning(
@ -136,13 +138,19 @@ def _apply_pruning(
break
if final_doc_ind is not None:
if is_manually_selected_docs:
if is_manually_selected_docs or use_sections:
# for document selection, only allow the final document to get truncated
# if more than that, then the user message is too long
if final_doc_ind != len(docs) - 1:
raise PruningError(
"LLM context window exceeded. Please de-select some documents or shorten your query."
)
if use_sections:
# Truncate the rest of the list since we're over the token limit
# for the last one, trim it. In this case, the Sections can be rather long
# so better to trim the back than throw away the whole thing.
docs = docs[: final_doc_ind + 1]
else:
raise PruningError(
"LLM context window exceeded. Please de-select some documents or shorten your query."
)
final_doc_desired_length = tokens_per_doc[final_doc_ind] - (
total_tokens - token_limit
@ -154,7 +162,7 @@ def _apply_pruning(
# not ideal, but it's the most reasonable thing to do
# NOTE: the frontend prevents documents from being selected if
# less than 75 tokens are available to try and avoid this situation
# from occuring in the first place
# from occurring in the first place
if final_doc_content_length <= 0:
logger.error(
f"Final doc ({docs[final_doc_ind].semantic_identifier}) content "
@ -168,7 +176,8 @@ def _apply_pruning(
tokenizer=llm_tokenizer,
)
else:
# for regular search, don't truncate the final document unless it's the only one
# For regular search, don't truncate the final document unless it's the only one
# If it's not the only one, we can throw it away; if it's the only one, we have to truncate it
if final_doc_ind != 0:
docs = docs[:final_doc_ind]
else:
@ -206,4 +215,5 @@ def prune_documents(
doc_relevance_list=doc_relevance_list,
token_limit=doc_token_limit,
is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
use_sections=document_pruning_config.use_sections,
)

View File

@ -48,6 +48,9 @@ class DocumentPruningConfig(BaseModel):
# e.g. we don't want to truncate each document to be no more
# than one chunk long
is_manually_selected_docs: bool = False
# If the user asks to include additional context chunks for each match, then different pruning
# is used. As many Sections as possible are included, and the last Section is truncated
use_sections: bool = False
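In the answer flows elsewhere in this commit, this flag is set from search_pipeline.ran_merge_chunk. A minimal sketch of enabling it directly (assuming the remaining DocumentPruningConfig fields keep their defaults):

from danswer.llm.answering.models import DocumentPruningConfig

# Section-aware pruning: keep as many whole Sections as fit in the token budget,
# then truncate only the last one instead of raising a PruningError
pruning_config = DocumentPruningConfig(use_sections=True)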
class CitationConfig(BaseModel):

View File

@ -11,7 +11,6 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_default_prompt
from danswer.db.models import Persona
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import LLMConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
@ -37,6 +36,7 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
_PER_MESSAGE_TOKEN_BUFFER = 7

View File

@ -4,7 +4,6 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
@ -12,6 +11,7 @@ from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
def _build_weak_llm_quotes_prompt(

View File

@ -114,13 +114,13 @@ def extract_citations_from_stream(
def build_citation_processor(
context_docs: list[LlmDoc],
context_docs: list[LlmDoc], search_order_docs: list[LlmDoc]
) -> StreamProcessor:
def stream_processor(tokens: Iterator[str]) -> AnswerQuestionStreamReturn:
yield from extract_citations_from_stream(
tokens=tokens,
context_docs=context_docs,
doc_id_to_rank_map=map_document_id_order(context_docs),
doc_id_to_rank_map=map_document_id_order(search_order_docs),
)
return stream_processor

View File

@ -15,10 +15,10 @@ from danswer.chat.models import DanswerQuote
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.indexing.models import InferenceChunk
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import clean_up_code_blocks

View File

@ -1,7 +1,7 @@
from collections.abc import Sequence
from danswer.chat.models import LlmDoc
from danswer.indexing.models import InferenceChunk
from danswer.search.models import InferenceChunk
def map_document_id_order(

View File

@ -33,8 +33,8 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.db.models import ChatMessage
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.interfaces import LLM
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

View File

@ -3,7 +3,7 @@ from collections.abc import Iterator
from sqlalchemy.orm import Session
from danswer.chat.chat_utils import llm_doc_from_inference_chunk
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.chat_utils import reorganize_citations
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
@ -40,7 +40,7 @@ from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.search.utils import chunks_to_search_docs
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
from danswer.server.query_and_chat.models import ChatMessageDetail
@ -128,6 +128,9 @@ def stream_answer_objects(
limit=query_req.retrieval_options.limit,
skip_rerank=query_req.skip_rerank,
skip_llm_chunk_filter=query_req.skip_llm_chunk_filter,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
),
user=user,
db_session=db_session,
@ -137,8 +140,8 @@ def stream_answer_objects(
)
# First fetch and return the top chunks so the user can immediately see some results
top_chunks = search_pipeline.reranked_docs
top_docs = chunks_to_search_docs(top_chunks)
top_sections = search_pipeline.reranked_sections
top_docs = chunks_or_sections_to_search_docs(top_sections)
reference_db_search_docs = [
create_db_search_doc(server_search_doc=top_doc, db_session=db_session)
@ -163,7 +166,7 @@ def stream_answer_objects(
# Yield the list of LLM selected chunks for showing the LLM selected icons in the UI
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=search_pipeline.relevant_chunk_indicies
relevant_chunk_indices=search_pipeline.relevant_chunk_indices
)
yield llm_relevance_filtering_response
@ -201,15 +204,16 @@ def stream_answer_objects(
else default_num_chunks
),
max_tokens=max_document_tokens,
use_sections=search_pipeline.ran_merge_chunk,
),
)
answer = Answer(
question=query_msg.message,
docs=[llm_doc_from_inference_chunk(chunk) for chunk in top_chunks],
docs=[llm_doc_from_inference_section(section) for section in top_sections],
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm_config=LLMConfig.from_persona(chat_session.persona),
doc_relevance_list=search_pipeline.chunk_relevance_list,
doc_relevance_list=search_pipeline.section_relevance_list,
single_message_history=history_str,
timeout=timeout,
)

View File

@ -9,6 +9,7 @@ from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
@ -22,7 +23,7 @@ class ThreadMessage(BaseModel):
role: MessageType = MessageType.USER
class DirectQARequest(BaseModel):
class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int

View File

@ -5,11 +5,11 @@ from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DocumentSource
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.search.models import InferenceChunk
def get_current_llm_day_time() -> str:

View File

@ -2,6 +2,7 @@ from datetime import datetime
from typing import Any
from pydantic import BaseModel
from pydantic import validator
from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import HYBRID_ALPHA
@ -9,6 +10,7 @@ from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.db.models import Persona
from danswer.indexing.models import BaseChunk
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
@ -42,7 +44,21 @@ class ChunkMetric(BaseModel):
score: float
class SearchRequest(BaseModel):
class ChunkContext(BaseModel):
# Additional surrounding context options; if full_doc is set, chunks are deduped
# If the surrounding contexts overlap, they are combined into one
chunks_above: int = 0
chunks_below: int = 0
full_doc: bool = False
@validator("chunks_above", "chunks_below", pre=True, each_item=False)
def check_non_negative(cls, value: int, field: Any) -> int:
if value < 0:
raise ValueError(f"{field.name} must be non-negative")
return value
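Since DirectQARequest and CreateChatMessageRequest elsewhere in this commit now inherit from ChunkContext, API callers can request the extra context directly in the request body. A small sketch of the model itself (field values are illustrative only):

from pydantic import ValidationError

from danswer.search.models import ChunkContext

# Defaults: no extra context around matched chunks
ctx = ChunkContext()  # chunks_above=0, chunks_below=0, full_doc=False

# Ask for one chunk above and two below every matched chunk
ctx = ChunkContext(chunks_above=1, chunks_below=2)

# Negative values are rejected by the validator
try:
    ChunkContext(chunks_above=-1)
except ValidationError as err:
    print(err)  # chunks_above must be non-negative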
class SearchRequest(ChunkContext):
"""Input to the SearchPipeline."""
query: str
@ -66,7 +82,7 @@ class SearchRequest(BaseModel):
arbitrary_types_allowed = True
class SearchQuery(BaseModel):
class SearchQuery(ChunkContext):
query: str
filters: IndexFilters
recency_bias_multiplier: float
@ -84,7 +100,7 @@ class SearchQuery(BaseModel):
frozen = True
class RetrievalDetails(BaseModel):
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history
# If the Persona is configured to not run search (0 chunks), this is bypassed
# If no Prompt is configured, the only search results are shown, this is bypassed
@ -92,7 +108,7 @@ class RetrievalDetails(BaseModel):
# Is this a real-time/streaming call or a question where Danswer can take more time?
# Used to determine reranking flow
real_time: bool = True
# The following have defaults in the Persona settings which can be overriden via
# The following have defaults in the Persona settings which can be overridden via
# the query, if None, then use Persona settings
filters: BaseFilters | None = None
enable_auto_detect_filters: bool | None = None
@ -101,6 +117,63 @@ class RetrievalDetails(BaseModel):
limit: int | None = None
class InferenceChunk(BaseChunk):
document_id: str
source_type: DocumentSource
semantic_identifier: str
boost: int
recency_bias: float
score: float | None
hidden: bool
metadata: dict[str, str | list[str]]
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
# when the doc was last updated
updated_at: datetime | None
primary_owners: list[str] | None = None
secondary_owners: list[str] | None = None
@property
def unique_id(self) -> str:
return f"{self.document_id}__{self.chunk_id}"
def __repr__(self) -> str:
blurb_words = self.blurb.split()
short_blurb = ""
for word in blurb_words:
if not short_blurb:
short_blurb = word
continue
if len(short_blurb) > 25:
break
short_blurb += " " + word
return f"Inference Chunk: {self.document_id} - {short_blurb}..."
def __eq__(self, other: Any) -> bool:
if not isinstance(other, InferenceChunk):
return False
return (self.document_id, self.chunk_id) == (other.document_id, other.chunk_id)
def __hash__(self) -> int:
return hash((self.document_id, self.chunk_id))
class InferenceSection(InferenceChunk):
"""Section is a combination of chunks. A section could be a single chunk, several consecutive
chunks or the entire document"""
combined_content: str
@classmethod
def from_chunk(
cls, inf_chunk: InferenceChunk, content: str | None = None
) -> "InferenceSection":
inf_chunk_data = inf_chunk.dict()
return cls(**inf_chunk_data, combined_content=content or inf_chunk.content)
class SearchDoc(BaseModel):
document_id: str
chunk_ind: int

View File

@ -1,16 +1,20 @@
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from typing import cast
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
@ -18,6 +22,31 @@ from danswer.search.models import SearchRequest
from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ChunkRange(BaseModel):
chunk: InferenceChunk
start: int
end: int
combined_content: str | None = None
def merge_chunk_intervals(chunk_ranges: list[ChunkRange]) -> list[ChunkRange]:
"""This acts on a single document to merge the overlapping ranges of sections
Algo explained here for easy understanding: https://leetcode.com/problems/merge-intervals
"""
sorted_ranges = sorted(chunk_ranges, key=lambda x: x.start)
ans: list[ChunkRange] = []
for chunk_range in sorted_ranges:
if not ans or ans[-1].end < chunk_range.start:
ans.append(chunk_range)
else:
ans[-1].end = max(ans[-1].end, chunk_range.end)
return ans
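A worked illustration of the merge on bare (start, end) pairs (a simplified standalone sketch rather than real ChunkRange objects): with chunks_above=1 and chunks_below=1, hits at chunk indices 2, 4, and 9 of the same document expand to the ranges below, and the first two windows touch and collapse into one.

ranges = [(1, 3), (3, 5), (8, 10)]  # expanded windows around chunks 2, 4 and 9

merged: list[list[int]] = []
for start, end in sorted(ranges):
    if not merged or merged[-1][1] < start:
        merged.append([start, end])
    else:
        merged[-1][1] = max(merged[-1][1], end)

# merged == [[1, 5], [8, 10]]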
class SearchPipeline:
@ -48,15 +77,148 @@ class SearchPipeline:
self._predicted_search_type: SearchType | None = None
self._predicted_flow: QueryFlow | None = None
self._retrieved_docs: list[InferenceChunk] | None = None
self._reranked_docs: list[InferenceChunk] | None = None
self._relevant_chunk_indicies: list[int] | None = None
self._retrieved_chunks: list[InferenceChunk] | None = None
self._retrieved_sections: list[InferenceSection] | None = None
self._reranked_chunks: list[InferenceChunk] | None = None
self._reranked_sections: list[InferenceSection] | None = None
self._relevant_chunk_indices: list[int] | None = None
# If chunks have been merged, the LLM filter flow no longer applies
# as the indices no longer match. Can be implemented later as needed
self.ran_merge_chunk = False
# generator state
self._postprocessing_generator: Generator[
list[InferenceChunk] | list[str], None, None
] | None = None
def _combine_chunks(self, post_rerank: bool) -> list[InferenceSection]:
if not post_rerank and self._retrieved_sections:
return self._retrieved_sections
if post_rerank and self._reranked_sections:
return self._reranked_sections
if not post_rerank:
chunks = self.retrieved_chunks
else:
chunks = self.reranked_chunks
if self._search_query is None:
# Should never happen
raise RuntimeError("Failed in Query Preprocessing")
functions_with_args: list[tuple[Callable, tuple]] = []
final_inference_sections = []
# Nothing to combine, just return the chunks
if (
not self._search_query.chunks_above
and not self._search_query.chunks_below
and not self._search_query.full_doc
):
return [InferenceSection.from_chunk(chunk) for chunk in chunks]
# If chunk merges have been run, LLM reranking loses meaning
# Needs reimplementation, out of scope for now
self.ran_merge_chunk = True
# Full doc setting takes priority
if self._search_query.full_doc:
seen_document_ids = set()
unique_chunks = []
for chunk in chunks:
if chunk.document_id not in seen_document_ids:
seen_document_ids.add(chunk.document_id)
unique_chunks.append(chunk)
functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk.document_id,
None, # Start chunk ind
None, # End chunk ind
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
for ind, chunk in enumerate(unique_chunks):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
final_inference_sections.append(
InferenceSection.from_chunk(chunk, content=combined_content)
)
return final_inference_sections
# General flow:
# - Combine chunks into lists by document_id
# - For each document, run merge-intervals to get combined ranges
# - Fetch all of the new chunks with contents for the combined ranges
# - Map it back to the combined ranges (which each know their "center" chunk)
# - Iterate over the chunks again and map them to the results above based on the chunk.
# This maintains the original chunk ordering. Note that we cannot simply sort by score here
# as the reranking flow may wipe the scores for many of the chunks.
doc_chunk_ranges_map = defaultdict(list)
for chunk in chunks:
doc_chunk_ranges_map[chunk.document_id].append(
ChunkRange(
chunk=chunk,
start=max(0, chunk.chunk_id - self._search_query.chunks_above),
# No max known ahead of time, filter will handle this anyway
end=chunk.chunk_id + self._search_query.chunks_below,
)
)
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
reverse_map = {r.chunk: r for doc_ranges in merged_ranges for r in doc_ranges}
for chunk_range in reverse_map.values():
functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunk.document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# List of lists of inference chunks; each inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
for ind, chunk_range in enumerate(reverse_map.values()):
inf_chunks = list_inference_chunks[ind]
combined_content = "\n".join([chunk.content for chunk in inf_chunks])
chunk_range.combined_content = combined_content
for chunk in chunks:
if chunk not in reverse_map:
continue
chunk_range = reverse_map[chunk]
final_inference_sections.append(
InferenceSection.from_chunk(
chunk_range.chunk, content=chunk_range.combined_content
)
)
return final_inference_sections
"""Pre-processing"""
def _run_preprocessing(self) -> None:
@ -101,11 +263,11 @@ class SearchPipeline:
"""Retrieval"""
@property
def retrieved_docs(self) -> list[InferenceChunk]:
if self._retrieved_docs is not None:
return self._retrieved_docs
def retrieved_chunks(self) -> list[InferenceChunk]:
if self._retrieved_chunks is not None:
return self._retrieved_chunks
self._retrieved_docs = retrieve_chunks(
self._retrieved_chunks = retrieve_chunks(
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
@ -114,47 +276,75 @@ class SearchPipeline:
retrieval_metrics_callback=self.retrieval_metrics_callback,
)
# self._retrieved_docs = chunks_to_search_docs(retrieved_chunks)
return cast(list[InferenceChunk], self._retrieved_docs)
return cast(list[InferenceChunk], self._retrieved_chunks)
@property
def retrieved_sections(self) -> list[InferenceSection]:
# Calls retrieved_chunks inside
self._retrieved_sections = self._combine_chunks(post_rerank=False)
return self._retrieved_sections
"""Post-Processing"""
@property
def reranked_docs(self) -> list[InferenceChunk]:
if self._reranked_docs is not None:
return self._reranked_docs
def reranked_chunks(self) -> list[InferenceChunk]:
if self._reranked_chunks is not None:
return self._reranked_chunks
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_docs,
retrieved_chunks=self.retrieved_chunks,
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_docs = cast(
self._reranked_chunks = cast(
list[InferenceChunk], next(self._postprocessing_generator)
)
return self._reranked_docs
return self._reranked_chunks
@property
def relevant_chunk_indicies(self) -> list[int]:
if self._relevant_chunk_indicies is not None:
return self._relevant_chunk_indicies
def reranked_sections(self) -> list[InferenceSection]:
# Calls reranked_chunks inside
self._reranked_sections = self._combine_chunks(post_rerank=True)
return self._reranked_sections
@property
def relevant_chunk_indices(self) -> list[int]:
# If chunks have been merged, then we cannot simply rely on the relevance of the
# leading chunk; there is no way to get the full relevance of the Section now
# without running a more token-heavy pass. This could be offered as an option
# later but is not implemented for now.
if self.ran_merge_chunk:
return []
if self._relevant_chunk_indices is not None:
return self._relevant_chunk_indices
# run first step of postprocessing generator if not already done
reranked_docs = self.reranked_docs
reranked_docs = self.reranked_chunks
relevant_chunk_ids = next(
cast(Generator[list[str], None, None], self._postprocessing_generator)
)
self._relevant_chunk_indicies = [
self._relevant_chunk_indices = [
ind
for ind, chunk in enumerate(reranked_docs)
if chunk.unique_id in relevant_chunk_ids
]
return self._relevant_chunk_indicies
return self._relevant_chunk_indices
@property
def chunk_relevance_list(self) -> list[bool]:
return [
True if ind in self.relevant_chunk_indicies else False
for ind in range(len(self.reranked_docs))
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
]
@property
def section_relevance_list(self) -> list[bool]:
if self.ran_merge_chunk:
return [False] * len(self.reranked_sections)
return [
True if ind in self.relevant_chunk_indices else False
for ind in range(len(self.reranked_chunks))
]
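Putting the refactor together, callers now consume Sections rather than raw chunks. The pattern used by the chat and one-shot answer flows above looks roughly like this (a sketch that assumes an already constructed SearchPipeline):

from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.search.pipeline import SearchPipeline
from danswer.search.utils import chunks_or_sections_to_search_docs

def summarize_results(search_pipeline: SearchPipeline):
    top_sections = search_pipeline.reranked_sections
    top_docs = chunks_or_sections_to_search_docs(top_sections)

    # Empty when chunks were merged, since chunk indices no longer line up with Sections
    relevant_indices = search_pipeline.relevant_chunk_indices

    llm_docs = [llm_doc_from_inference_section(section) for section in top_sections]
    return top_docs, relevant_indices, llm_docs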

View File

@ -9,8 +9,8 @@ from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.indexing.models import InferenceChunk
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery

View File

@ -181,6 +181,9 @@ def retrieval_preprocessing(
offset=offset or 0,
skip_rerank=skip_rerank,
skip_llm_chunk_filter=not llm_chunk_filter,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,
full_doc=search_request.full_doc,
),
predicted_search_type,
predicted_flow,

View File

@ -11,10 +11,10 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import EmbedTextType
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
@ -244,7 +244,7 @@ def inference_documents_from_ids(
filters = IndexFilters(access_control_list=None)
functions_with_args: list[tuple[Callable, tuple]] = [
(document_index.id_based_retrieval, (doc_id, None, filters))
(document_index.id_based_retrieval, (doc_id, None, None, filters))
for doc_id in doc_ids_set
]

View File

@ -1,8 +1,13 @@
from danswer.indexing.models import InferenceChunk
from collections.abc import Sequence
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
from danswer.search.models import SearchDoc
def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
def chunks_or_sections_to_search_docs(
chunks: Sequence[InferenceChunk | InferenceSection] | None,
) -> list[SearchDoc]:
search_docs = (
[
SearchDoc(

View File

@ -39,7 +39,8 @@ def get_document_info(
inference_chunks = document_index.id_based_retrieval(
document_id=document_id,
chunk_ind=None,
min_chunk_ind=None,
max_chunk_ind=None,
filters=filters,
)
@ -86,7 +87,8 @@ def get_chunk_info(
inference_chunks = document_index.id_based_retrieval(
document_id=document_id,
chunk_ind=chunk_id,
min_chunk_ind=chunk_id,
max_chunk_ind=chunk_id,
filters=filters,
)

View File

@ -72,7 +72,7 @@ def gpt_search(
),
user=None,
db_session=db_session,
).reranked_docs
).reranked_chunks
return GptSearchResponse(
matching_document_chunks=[

View File

@ -12,6 +12,7 @@ from danswer.db.enums import ChatSessionSharedStatus
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import BaseFilters
from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
from danswer.search.models import SearchType
@ -83,7 +84,7 @@ Currently the different branches are generated by changing the search query
"""
class CreateChatMessageRequest(BaseModel):
class CreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id"""
chat_session_id: int

View File

@ -19,7 +19,7 @@ from danswer.search.models import IndexFilters
from danswer.search.models import SearchDoc
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.preprocessing.danswer_helper import recommend_search_flow
from danswer.search.utils import chunks_to_search_docs
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.query_and_chat.models import AdminSearchRequest
@ -69,7 +69,7 @@ def admin_search(
matching_chunks = document_index.admin_retrieval(query=query, filters=final_filters)
documents = chunks_to_search_docs(matching_chunks)
documents = chunks_or_sections_to_search_docs(matching_chunks)
# Deduplicate documents by id
deduplicated_documents: list[SearchDoc] = []

View File

@ -8,8 +8,8 @@ from typing import TextIO
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.doc_pruning import reorder_docs
from danswer.search.models import InferenceChunk
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
@ -92,7 +92,7 @@ def get_search_results(
rerank_metrics_callback=rerank_metrics.record_metric,
)
top_chunks = search_pipeline.reranked_docs
top_chunks = search_pipeline.reranked_chunks
llm_chunk_selection = search_pipeline.chunk_relevance_list
return (

View File

@ -2,13 +2,13 @@ import textwrap
import unittest
from danswer.configs.constants import DocumentSource
from danswer.indexing.models import InferenceChunk
from danswer.llm.answering.stream_processing.quotes_processing import (
match_quotes_to_docs,
)
from danswer.llm.answering.stream_processing.quotes_processing import (
separate_answer_quotes,
)
from danswer.search.models import InferenceChunk
class TestQAPostprocessing(unittest.TestCase):