diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 25f392ebd73c..c8153bd156d5 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -270,6 +270,11 @@ def stream_chat_message_objects(
     3. [always] A set of streamed LLM tokens or an error anywhere along the line
         if something fails
     4. [always] Details on the final AI response message that is created
     """
+    # Currently surrounding context is not supported for chat
+    # Chat is already token heavy and harder for the model to process plus it would roll history over much faster
+    new_msg_req.chunks_above = 0
+    new_msg_req.chunks_below = 0
+
     try:
         user_id = user.id if user is not None else None
diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py
index d4f5a5e807ea..2b6b0990e1d5 100644
--- a/backend/danswer/configs/chat_configs.py
+++ b/backend/danswer/configs/chat_configs.py
@@ -31,8 +31,9 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2.0
 DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
 # For the highest matching base size chunk, how many chunks above and below do we pull in by default
 # Note this is not in any of the deployment configs yet
-CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
-CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
+# Currently only applies to search flow not chat
+CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)
+CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)
 # Whether the LLM should be used to decide if a search would help given the chat history
 DISABLE_LLM_CHOOSE_SEARCH = (
     os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index e2b8ee7f62a4..e7dce6f12f5a 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -82,6 +82,9 @@ GEN_AI_LLM_PROVIDER_TYPE = os.environ.get("GEN_AI_LLM_PROVIDER_TYPE") or None
 GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None
 # Set this to be enough for an answer + quotes. Also used for Chat
 GEN_AI_MAX_OUTPUT_TOKENS = int(os.environ.get("GEN_AI_MAX_OUTPUT_TOKENS") or 1024)
+
+# Typically, GenAI models nowadays are at least 4K tokens
+GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096
 # Number of tokens from chat history to include at maximum
 # 3000 should be enough context regardless of use, no need to include as much as possible
 # as this drives up the cost unnecessarily
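Note on the config changes above: a minimal standalone sketch (not part of the patch) of how the new surrounding-context defaults resolve from the environment. The environment values used here are hypothetical.

    import os

    # Unset or empty -> falls back to the new default of 1 chunk of surrounding context
    os.environ.pop("CONTEXT_CHUNKS_ABOVE", None)
    CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1)

    # An explicit value set in the environment still wins
    os.environ["CONTEXT_CHUNKS_BELOW"] = "2"
    CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1)

    print(CONTEXT_CHUNKS_ABOVE, CONTEXT_CHUNKS_BELOW)  # -> 1 2

These defaults only feed the search flow; the chat path above pins chunks_above and chunks_below to 0 before any of this is consulted.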
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index c92b6c3c153b..025974bb015e 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -1206,6 +1206,8 @@ class Persona(Base):
     description: Mapped[str] = mapped_column(String)
     # Number of chunks to pass to the LLM for generation.
     num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
+    chunks_above: Mapped[int] = mapped_column(Integer)
+    chunks_below: Mapped[int] = mapped_column(Integer)
     # Pass every chunk through LLM for evaluation, fairly expensive
     # Can be turned off globally by admin, in which case, this setting is ignored
     llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py
index 8c25b27b961f..d8f71c3eabbd 100644
--- a/backend/danswer/db/persona.py
+++ b/backend/danswer/db/persona.py
@@ -15,6 +15,8 @@ from sqlalchemy.orm import Session

 from danswer.auth.schemas import UserRole
 from danswer.configs.chat_configs import BING_API_KEY
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
 from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import DocumentSet
@@ -67,28 +69,14 @@ def create_update_persona(
     # Permission to actually use these is checked later
     try:
-        persona = upsert_persona(
-            persona_id=persona_id,
-            user=user,
-            name=create_persona_request.name,
-            description=create_persona_request.description,
-            num_chunks=create_persona_request.num_chunks,
-            llm_relevance_filter=create_persona_request.llm_relevance_filter,
-            llm_filter_extraction=create_persona_request.llm_filter_extraction,
-            recency_bias=create_persona_request.recency_bias,
-            prompt_ids=create_persona_request.prompt_ids,
-            tool_ids=create_persona_request.tool_ids,
-            document_set_ids=create_persona_request.document_set_ids,
-            llm_model_provider_override=create_persona_request.llm_model_provider_override,
-            llm_model_version_override=create_persona_request.llm_model_version_override,
-            starter_messages=create_persona_request.starter_messages,
-            is_public=create_persona_request.is_public,
-            db_session=db_session,
-            icon_color=create_persona_request.icon_color,
-            icon_shape=create_persona_request.icon_shape,
-            uploaded_image_id=create_persona_request.uploaded_image_id,
-            remove_image=create_persona_request.remove_image,
-        )
+        persona_data = {
+            "persona_id": persona_id,
+            "user": user,
+            "db_session": db_session,
+            **create_persona_request.dict(exclude={"users", "groups"}),
+        }
+
+        persona = upsert_persona(**persona_data)

         versioned_make_persona_private = fetch_versioned_implementation(
             "danswer.db.persona", "make_persona_private"
         )
@@ -352,6 +340,8 @@ def upsert_persona(
     display_priority: int | None = None,
     is_visible: bool = True,
     remove_image: bool | None = None,
+    chunks_above: int = CONTEXT_CHUNKS_ABOVE,
+    chunks_below: int = CONTEXT_CHUNKS_BELOW,
 ) -> Persona:
     if persona_id is not None:
         persona = db_session.query(Persona).filter_by(id=persona_id).first()
@@ -398,6 +388,8 @@ def upsert_persona(
         persona.name = name
         persona.description = description
         persona.num_chunks = num_chunks
+        persona.chunks_above = chunks_above
+        persona.chunks_below = chunks_below
         persona.llm_relevance_filter = llm_relevance_filter
         persona.llm_filter_extraction = llm_filter_extraction
         persona.recency_bias = recency_bias
@@ -435,6 +427,8 @@ def upsert_persona(
             name=name,
             description=description,
             num_chunks=num_chunks,
+            chunks_above=chunks_above,
+            chunks_below=chunks_below,
             llm_relevance_filter=llm_relevance_filter,
             llm_filter_extraction=llm_filter_extraction,
             recency_bias=recency_bias,
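For context on the persona_data refactor above: the call site now forwards the request model's fields as keyword arguments, so a new persona setting such as chunks_above or chunks_below reaches upsert_persona without editing create_update_persona again. A simplified sketch of the pattern; the request model and upsert function below are stripped-down stand-ins, not Danswer's real schemas.

    from pydantic import BaseModel

    class CreatePersonaRequest(BaseModel):
        name: str
        description: str
        num_chunks: float | None = None
        chunks_above: int = 1
        chunks_below: int = 1
        users: list[str] = []   # sharing info, applied separately (make_persona_private)
        groups: list[int] = []  # sharing info, applied separately

    def upsert_persona(
        *, persona_id: int | None, user: str | None, db_session: object, **fields: object
    ) -> dict:
        # Stand-in for the real upsert; it just shows which keyword arguments arrive
        return {"persona_id": persona_id, **fields}

    req = CreatePersonaRequest(name="Support", description="Answers support questions")
    persona_data = {
        "persona_id": None,
        "user": None,
        "db_session": None,
        # users/groups are excluded here and handled by the sharing logic instead
        **req.dict(exclude={"users", "groups"}),
    }
    persona = upsert_persona(**persona_data)
    print(sorted(persona))  # includes chunks_above and chunks_below automatically

The trade-off of the dict-unpacking style is that any field added to the request model needs a matching parameter on upsert_persona (such as the new chunks_above/chunks_below defaults), or the call fails at runtime rather than at type-check time.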
diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py
index 432ea7338a4d..3d05a08c47b6 100644
--- a/backend/danswer/llm/answering/models.py
+++ b/backend/danswer/llm/answering/models.py
@@ -92,6 +92,16 @@ class DocumentPruningConfig(BaseModel):
     using_tool_message: bool = False


+class ContextualPruningConfig(DocumentPruningConfig):
+    num_chunk_multiple: int
+
+    @classmethod
+    def from_doc_pruning_config(
+        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
+    ) -> "ContextualPruningConfig":
+        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
+
+
 class CitationConfig(BaseModel):
     all_docs_useful: bool = False

diff --git a/backend/danswer/llm/answering/prune_and_merge.py b/backend/danswer/llm/answering/prune_and_merge.py
index 43f592a8385e..0193de1f2aae 100644
--- a/backend/danswer/llm/answering/prune_and_merge.py
+++ b/backend/danswer/llm/answering/prune_and_merge.py
@@ -10,7 +10,7 @@ from danswer.chat.models import (
 )
 from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.llm.answering.models import DocumentPruningConfig
+from danswer.llm.answering.models import ContextualPruningConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
 from danswer.llm.interfaces import LLMConfig
@@ -266,29 +266,36 @@ def prune_sections(
     prompt_config: PromptConfig,
     llm_config: LLMConfig,
     question: str,
-    document_pruning_config: DocumentPruningConfig,
+    contextual_pruning_config: ContextualPruningConfig,
 ) -> list[InferenceSection]:
     # Assumes the sections are score ordered with highest first
     if section_relevance_list is not None:
         assert len(sections) == len(section_relevance_list)

+    actual_num_chunks = (
+        contextual_pruning_config.max_chunks
+        * contextual_pruning_config.num_chunk_multiple
+        if contextual_pruning_config.max_chunks
+        else None
+    )
+
     token_limit = _compute_limit(
         prompt_config=prompt_config,
         llm_config=llm_config,
         question=question,
-        max_chunks=document_pruning_config.max_chunks,
-        max_window_percentage=document_pruning_config.max_window_percentage,
-        max_tokens=document_pruning_config.max_tokens,
-        tool_token_count=document_pruning_config.tool_num_tokens,
+        max_chunks=actual_num_chunks,
+        max_window_percentage=contextual_pruning_config.max_window_percentage,
+        max_tokens=contextual_pruning_config.max_tokens,
+        tool_token_count=contextual_pruning_config.tool_num_tokens,
     )
     return _apply_pruning(
         sections=sections,
         section_relevance_list=section_relevance_list,
         token_limit=token_limit,
-        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
-        use_sections=document_pruning_config.use_sections,  # Now default True
-        using_tool_message=document_pruning_config.using_tool_message,
+        is_manually_selected_docs=contextual_pruning_config.is_manually_selected_docs,
+        use_sections=contextual_pruning_config.use_sections,  # Now default True
+        using_tool_message=contextual_pruning_config.using_tool_message,
         llm_config=llm_config,
     )


@@ -360,7 +367,7 @@ def prune_and_merge_sections(
     prompt_config: PromptConfig,
     llm_config: LLMConfig,
     question: str,
-    document_pruning_config: DocumentPruningConfig,
+    contextual_pruning_config: ContextualPruningConfig,
 ) -> list[InferenceSection]:
     # Assumes the sections are score ordered with highest first
     remaining_sections = prune_sections(
@@ -369,7 +376,7 @@ def prune_and_merge_sections(
         prompt_config=prompt_config,
         llm_config=llm_config,
         question=question,
-        document_pruning_config=document_pruning_config,
+        contextual_pruning_config=contextual_pruning_config,
     )

     merged_sections = _merge_sections(sections=remaining_sections)
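To illustrate the new pruning math: ContextualPruningConfig carries a num_chunk_multiple, and prune_sections scales the persona's max_chunks by it so the token budget accounts for the surrounding chunks that get merged in later. A minimal sketch of that calculation; the real DocumentPruningConfig has more fields than shown here.

    from pydantic import BaseModel

    class DocumentPruningConfig(BaseModel):
        max_chunks: int | None = None

    class ContextualPruningConfig(DocumentPruningConfig):
        num_chunk_multiple: int

        @classmethod
        def from_doc_pruning_config(
            cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
        ) -> "ContextualPruningConfig":
            return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())

    # A persona allowed 10 chunks, with 1 chunk of context above and below (multiple = 3)
    config = ContextualPruningConfig.from_doc_pruning_config(
        num_chunk_multiple=3, doc_pruning_config=DocumentPruningConfig(max_chunks=10)
    )
    actual_num_chunks = (
        config.max_chunks * config.num_chunk_multiple if config.max_chunks else None
    )
    print(actual_num_chunks)  # 30 base-size chunks' worth of budget for _compute_limit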
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index 73d482f354dc..4172a1f5e611 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -32,6 +32,7 @@ from litellm.exceptions import UnprocessableEntityError  # type: ignore
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
+from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.models import ChatMessage
 from danswer.file_store.models import ChatFileType
@@ -353,9 +354,9 @@ def get_llm_max_tokens(
         raise RuntimeError("No max tokens found for LLM")
     except Exception:
         logger.exception(
-            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to 4096."
+            f"Failed to get max tokens for LLM with name {model_name}. Defaulting to {GEN_AI_MODEL_DEFAULT_MAX_TOKENS}."
         )
-        return 4096
+        return GEN_AI_MODEL_DEFAULT_MAX_TOKENS


 def get_max_input_tokens(
diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py
index d7027b929a9f..85cb769633bd 100644
--- a/backend/danswer/search/models.py
+++ b/backend/danswer/search/models.py
@@ -4,8 +4,6 @@ from typing import Any
 from pydantic import BaseModel
 from pydantic import validator

-from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
-from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
 from danswer.db.models import Persona
@@ -73,15 +71,15 @@ class ChunkMetric(BaseModel):


 class ChunkContext(BaseModel):
-    # Additional surrounding context options, if full doc, then chunks are deduped
-    # If surrounding context overlap, it is combined into one
-    chunks_above: int = CONTEXT_CHUNKS_ABOVE
-    chunks_below: int = CONTEXT_CHUNKS_BELOW
+    # If not specified (None), picked up from Persona settings if there is space
+    # if specified (even if 0), it always uses the specified number of chunks above and below
+    chunks_above: int | None = None
+    chunks_below: int | None = None
     full_doc: bool = False

     @validator("chunks_above", "chunks_below", pre=True, each_item=False)
     def check_non_negative(cls, value: int, field: Any) -> int:
-        if value < 0:
+        if value is not None and value < 0:
             raise ValueError(f"{field.name} must be non-negative")
         return value

@@ -117,6 +115,10 @@ class SearchQuery(ChunkContext):
     evaluation_type: LLMEvaluationType
     filters: IndexFilters

+    # by this point, the chunks_above and chunks_below must be set
+    chunks_above: int
+    chunks_below: int
+
     rerank_settings: RerankingDetails | None
     hybrid_alpha: float
     recency_bias_multiplier: float
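A small standalone sketch of the relaxed ChunkContext contract above (pydantic v1 style, matching the validator import in this file): None now means "defer to the persona or global default", an explicit value, including 0, is used as-is, and only negative values are rejected.

    from typing import Any

    from pydantic import BaseModel
    from pydantic import validator

    class ChunkContext(BaseModel):
        chunks_above: int | None = None
        chunks_below: int | None = None
        full_doc: bool = False

        @validator("chunks_above", "chunks_below", pre=True, each_item=False)
        def check_non_negative(cls, value: int, field: Any) -> int:
            if value is not None and value < 0:
                raise ValueError(f"{field.name} must be non-negative")
            return value

    ChunkContext()                     # OK: defer to persona / global defaults
    ChunkContext(chunks_above=0)       # OK: explicitly disable surrounding context
    try:
        ChunkContext(chunks_below=-1)  # rejected by the validator
    except ValueError as err:
        print(err)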
diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py
index 52e046153b9d..f93ab546ce80 100644
--- a/backend/danswer/search/pipeline.py
+++ b/backend/danswer/search/pipeline.py
@@ -11,7 +11,6 @@ from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.interfaces import VespaChunkRequest
-from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.prune_and_merge import _merge_sections
 from danswer.llm.answering.prune_and_merge import ChunkRange
@@ -56,7 +55,6 @@ class SearchPipeline:
         ) = None,
         rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
         prompt_config: PromptConfig | None = None,
-        pruning_config: DocumentPruningConfig | None = None,
     ):
         self.search_request = search_request
         self.user = user
@@ -73,7 +71,6 @@ class SearchPipeline:
             secondary_index_name=None,
         )
         self.prompt_config: PromptConfig | None = prompt_config
-        self.pruning_config: DocumentPruningConfig | None = pruning_config

         # Preprocessing steps generate this
         self._search_query: SearchQuery | None = None
@@ -139,10 +136,6 @@ class SearchPipeline:
     """Retrieval and Postprocessing"""

     def _get_chunks(self) -> list[InferenceChunk]:
-        """TODO as a future extension:
-        If large chunks (above 512 tokens) are used which cannot be directly fed to the LLM,
-        This step should run the two retrievals to get all of the base size chunks
-        """
         if self._retrieved_chunks is not None:
             return self._retrieved_chunks

@@ -178,7 +171,6 @@ class SearchPipeline:
         chunk_requests: list[VespaChunkRequest] = []

         # Full doc setting takes priority
-
         if self.search_query.full_doc:
             seen_document_ids = set()
diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py
index ac5f87024307..5ecee9898337 100644
--- a/backend/danswer/search/preprocessing/preprocessing.py
+++ b/backend/danswer/search/preprocessing/preprocessing.py
@@ -1,6 +1,8 @@
 from sqlalchemy.orm import Session

 from danswer.configs.chat_configs import BASE_RECENCY_DECAY
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
 from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import HYBRID_ALPHA
@@ -199,6 +201,20 @@ def retrieval_preprocessing(
     if search_request.hybrid_alpha:
         hybrid_alpha = search_request.hybrid_alpha

+    # Search request overrides anything else as it's explicitly set by the request
+    # If not explicitly specified, use the persona settings if they exist
+    # Otherwise, use the global defaults
+    chunks_above = (
+        search_request.chunks_above
+        if search_request.chunks_above is not None
+        else (persona.chunks_above if persona else CONTEXT_CHUNKS_ABOVE)
+    )
+    chunks_below = (
+        search_request.chunks_below
+        if search_request.chunks_below is not None
+        else (persona.chunks_below if persona else CONTEXT_CHUNKS_BELOW)
+    )
+
     return SearchQuery(
         query=query,
         processed_keywords=processed_keywords,
@@ -216,7 +232,7 @@ def retrieval_preprocessing(
         max_llm_filter_sections=rerank_settings.num_rerank
         if rerank_settings
         else NUM_POSTPROCESSED_RESULTS,
-        chunks_above=search_request.chunks_above,
-        chunks_below=search_request.chunks_below,
+        chunks_above=chunks_above,
+        chunks_below=chunks_below,
         full_doc=search_request.full_doc,
     )
diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py
index 669851ba8bbf..cffffa09145e 100644
--- a/backend/danswer/search/retrieval/search_runner.py
+++ b/backend/danswer/search/retrieval/search_runner.py
@@ -266,7 +266,7 @@ def retrieve_chunks(
     if not top_chunks:
         logger.warning(
-            f"{query.search_type.value.capitalize()} search returned no results "
+            f"Hybrid ({query.search_type.value.capitalize()}) search returned no results "
             f"with filters: {query.filters}"
         )
         return []

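The resolution order implemented in retrieval_preprocessing above is: an explicit value on the search request wins (even 0), then the persona's setting, then the global default. A hedged, standalone sketch of that order, with the persona collapsed to a single optional int; in the real code the fallback checks whether a persona object exists at all, since the column itself is non-nullable.

    CONTEXT_CHUNKS_ABOVE = 1  # stand-in for the chat_configs default

    def resolve_chunks_above(
        request_chunks_above: int | None, persona_chunks_above: int | None
    ) -> int:
        if request_chunks_above is not None:  # explicitly set by the caller, even 0, wins
            return request_chunks_above
        if persona_chunks_above is not None:  # persona-level configuration
            return persona_chunks_above
        return CONTEXT_CHUNKS_ABOVE           # global default

    print(resolve_chunks_above(0, 3))        # 0 - the request override always wins
    print(resolve_chunks_above(None, 3))     # 3 - falls back to the persona
    print(resolve_chunks_above(None, None))  # 1 - falls back to the global default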
diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py
index 5a44f3761085..4ec6ac050e3a 100644
--- a/backend/danswer/tools/search/search_tool.py
+++ b/backend/danswer/tools/search/search_tool.py
@@ -11,12 +11,17 @@ from danswer.chat.models import DanswerContext
 from danswer.chat.models import DanswerContexts
 from danswer.chat.models import LlmDoc
 from danswer.chat.models import SectionRelevancePiece
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
+from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
+from danswer.configs.model_configs import GEN_AI_MODEL_DEFAULT_MAX_TOKENS
 from danswer.db.models import Persona
 from danswer.db.models import User
 from danswer.dynamic_configs.interface import JSON_ro
+from danswer.llm.answering.models import ContextualPruningConfig
 from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
+from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
 from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
 from danswer.llm.answering.prune_and_merge import prune_sections
 from danswer.llm.interfaces import LLM
@@ -84,8 +89,8 @@ class SearchTool(Tool):
         # if specified, will not actually run a search and will instead return these
         # sections. Used when the user selects specific docs to talk to
         selected_sections: list[InferenceSection] | None = None,
-        chunks_above: int = 0,
-        chunks_below: int = 0,
+        chunks_above: int | None = None,
+        chunks_below: int | None = None,
         full_doc: bool = False,
         bypass_acl: bool = False,
     ) -> None:
@@ -95,17 +100,48 @@ class SearchTool(Tool):
         self.prompt_config = prompt_config
         self.llm = llm
         self.fast_llm = fast_llm
-        self.pruning_config = pruning_config
         self.evaluation_type = evaluation_type

         self.selected_sections = selected_sections
-        self.chunks_above = chunks_above
-        self.chunks_below = chunks_below

         self.full_doc = full_doc
         self.bypass_acl = bypass_acl
         self.db_session = db_session

+        self.chunks_above = (
+            chunks_above
+            if chunks_above is not None
+            else (
+                persona.chunks_above
+                if persona.chunks_above is not None
+                else CONTEXT_CHUNKS_ABOVE
+            )
+        )
+        self.chunks_below = (
+            chunks_below
+            if chunks_below is not None
+            else (
+                persona.chunks_below
+                if persona.chunks_below is not None
+                else CONTEXT_CHUNKS_BELOW
+            )
+        )
+
+        # For small context models, don't include additional surrounding context
+        # The 3 here for at least minimum 1 above, 1 below and 1 for the middle chunk
+        max_llm_tokens = compute_max_llm_input_tokens(self.llm.config)
+        if max_llm_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS:
+            self.chunks_above = 0
+            self.chunks_below = 0
+
+        num_chunk_multiple = self.chunks_above + self.chunks_below + 1
+
+        self.contextual_pruning_config = (
+            ContextualPruningConfig.from_doc_pruning_config(
+                num_chunk_multiple=num_chunk_multiple, doc_pruning_config=pruning_config
+            )
+        )
+
     @property
     def name(self) -> str:
         return self._NAME
@@ -216,7 +252,7 @@ class SearchTool(Tool):
             prompt_config=self.prompt_config,
             llm_config=self.llm.config,
             question=query,
-            document_pruning_config=self.pruning_config,
+            contextual_pruning_config=self.contextual_pruning_config,
         )

         llm_docs = [
@@ -260,7 +296,6 @@ class SearchTool(Tool):
             bypass_acl=self.bypass_acl,
             db_session=self.db_session,
             prompt_config=self.prompt_config,
-            pruning_config=self.pruning_config,
         )

         yield ToolResponse(
@@ -301,7 +336,7 @@
             prompt_config=self.prompt_config,
             llm_config=self.llm.config,
             question=query,
-            document_pruning_config=self.pruning_config,
+            contextual_pruning_config=self.contextual_pruning_config,
         )

         llm_docs = [
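To make the new guard in SearchTool.__init__ concrete: when the model's input budget is smaller than three times GEN_AI_MODEL_DEFAULT_MAX_TOKENS, surrounding context is dropped entirely, and the chunk multiple handed to ContextualPruningConfig is chunks_above + chunks_below + 1. A minimal sketch with an assumed token budget passed in directly instead of calling compute_max_llm_input_tokens.

    GEN_AI_MODEL_DEFAULT_MAX_TOKENS = 4096

    def resolve_chunk_multiple(
        chunks_above: int, chunks_below: int, max_llm_input_tokens: int
    ) -> tuple[int, int, int]:
        # The 3x floor covers a minimum of 1 chunk above, 1 below, and the middle chunk
        if max_llm_input_tokens < 3 * GEN_AI_MODEL_DEFAULT_MAX_TOKENS:
            chunks_above, chunks_below = 0, 0
        num_chunk_multiple = chunks_above + chunks_below + 1
        return chunks_above, chunks_below, num_chunk_multiple

    print(resolve_chunk_multiple(1, 1, 16_000))  # (1, 1, 3) - large context keeps the extra chunks
    print(resolve_chunk_multiple(1, 1, 8_000))   # (0, 0, 1) - small context drops them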
diff --git a/backend/ee/danswer/server/query_and_chat/chat_backend.py b/backend/ee/danswer/server/query_and_chat/chat_backend.py
index 5333c91fc86f..7661b7781edc 100644
--- a/backend/ee/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/chat_backend.py
@@ -106,8 +106,9 @@ def handle_simplified_chat_message(
         search_doc_ids=chat_message_req.search_doc_ids,
         retrieval_options=retrieval_options,
         query_override=chat_message_req.query_override,
-        chunks_above=chat_message_req.chunks_above,
-        chunks_below=chat_message_req.chunks_below,
+        # Currently only applies to search flow not chat
+        chunks_above=0,
+        chunks_below=0,
         full_doc=chat_message_req.full_doc,
     )

@@ -232,8 +233,8 @@ def handle_send_message_simple_with_history(
         search_doc_ids=None,
         retrieval_options=req.retrieval_options,
         query_override=rephrased_query,
-        chunks_above=req.chunks_above,
-        chunks_below=req.chunks_below,
+        chunks_above=0,
+        chunks_below=0,
         full_doc=req.full_doc,
     )

diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py
index f6599abea067..aef3648220e4 100644
--- a/backend/ee/danswer/server/query_and_chat/query_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/query_backend.py
@@ -54,6 +54,7 @@ def handle_search_request(
     logger.notice(f"Received document search query: {query}")

     llm, fast_llm = get_default_llms()
+
     search_pipeline = SearchPipeline(
         search_request=SearchRequest(
             query=query,