Allowing users to set Search Settings (#2106)

Yuhong Sun 2024-08-10 20:48:58 -07:00 committed by GitHub
parent 7358ece008
commit d60fb15ad3
30 changed files with 281 additions and 115 deletions

View File

@ -70,7 +70,7 @@
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"ENABLE_MINI_CHUNK": "false",
"ENABLE_MULTIPASS_INDEXING": "false",
"LOG_LEVEL": "DEBUG",
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."

View File

@ -264,7 +264,9 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores an additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Finer-grained chunking for more detail retention
# Slightly larger since the sentence-aware split is a max cutoff, so most mini-chunks will be under MINI_CHUNK_SIZE
# tokens. But it needs to be at least as big as 1/4th of the chunk size to avoid a tiny mini-chunk at the end
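
For reference, a minimal sketch of the sizing rule the comment above describes; the concrete values are assumptions for illustration, not taken from this diff:

# Illustrative values only; the real defaults live in the configs above.
DOC_EMBEDDING_CONTEXT_SIZE = 512  # assumed chunk token budget
MINI_CHUNK_SIZE = 150             # assumed mini-chunk cutoff

# The mini-chunk cutoff must be at least 1/4th of the chunk size to avoid
# a tiny trailing mini-chunk.
assert MINI_CHUNK_SIZE >= DOC_EMBEDDING_CONTEXT_SIZE / 4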

View File

@ -9,7 +9,7 @@ NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
# We want this to be approximately the number of results we want to show on the first page
# It cannot be too large due to cost and latency implications
NUM_RERANKED_RESULTS = 20
NUM_POSTPROCESSED_RESULTS = 20
# May be less depending on model
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
@ -43,11 +43,6 @@ DISABLE_LLM_QUERY_REPHRASE = (
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# The keyword token classifier model and NLTK are both English-facing; if using multilingual, just skip this
if os.environ.get("EDIT_KEYWORD_QUERY"):
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("MULTILINGUAL_QUERY_EXPANSION")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
HYBRID_ALPHA_KEYWORD = max(

View File

@ -75,6 +75,7 @@ UNNAMED_KEY_PLACEHOLDER = "Unnamed"
# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"

View File

@ -50,7 +50,7 @@ from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.search.search_settings import get_search_settings
srl = SlackRateLimiter()
@ -223,15 +223,23 @@ def handle_regular_answer(
enable_auto_detect_filters=auto_detect_filters,
)
# Always apply the reranking settings if they exist; this is the non-streaming flow
saved_search_settings = get_search_settings()
# This includes throwing out the answer via reflexion
answer = _get_answer(
DirectQARequest(
messages=messages,
multilingual_query_expansion=saved_search_settings.multilingual_expansion
if saved_search_settings
else None,
prompt_id=prompt.id if prompt else None,
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
rerank_settings=saved_search_settings.to_reranking_detail()
if saved_search_settings
else None,
)
)
except Exception as e:

View File

@ -4,7 +4,6 @@ from typing import Optional
from typing import TYPE_CHECKING
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
@ -15,7 +14,6 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.utils.logger import setup_logger
@ -124,19 +122,20 @@ def _get_metadata_suffix_for_document_index(
def chunk_document(
document: Document,
embedder: IndexingEmbedder,
model_name: str,
provider_type: str | None,
enable_multipass: bool,
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE, # Used for both title and content
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
mini_chunk_size: int = MINI_CHUNK_SIZE,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[DocAwareChunk]:
from llama_index.text_splitter import SentenceSplitter
tokenizer = get_tokenizer(
model_name=embedder.model_name,
provider_type=embedder.provider_type,
model_name=model_name,
provider_type=provider_type,
)
blurb_splitter = SentenceSplitter(
@ -212,7 +211,7 @@ def chunk_document(
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
if enable_multipass and chunk_text.strip()
else None,
)
)
@ -226,7 +225,7 @@ def chunk_document(
start_chunk_id=len(chunks),
chunk_splitter=chunk_splitter,
mini_chunk_splitter=mini_chunk_splitter
if enable_mini_chunk and chunk_text.strip()
if enable_multipass and chunk_text.strip()
else None,
blurb=extract_blurb(section_text, blurb_splitter),
title_prefix=title_prefix,
@ -260,7 +259,7 @@ def chunk_document(
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
if enable_multipass and chunk_text.strip()
else None,
)
)
@ -282,7 +281,7 @@ def chunk_document(
metadata_suffix_semantic=metadata_suffix_semantic,
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
if enable_mini_chunk and chunk_text.strip()
if enable_multipass and chunk_text.strip()
else None,
)
)
@ -296,18 +295,28 @@ class Chunker:
def chunk(
self,
document: Document,
embedder: IndexingEmbedder,
) -> list[DocAwareChunk]:
raise NotImplementedError
class DefaultChunker(Chunker):
def __init__(
self, model_name: str, provider_type: str | None, enable_multipass: bool
):
self.model_name = model_name
self.provider_type = provider_type
self.enable_multipass = enable_multipass
def chunk(
self,
document: Document,
embedder: IndexingEmbedder,
) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with Gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
return chunk_document(document, embedder=embedder)
return chunk_document(
document=document,
model_name=self.model_name,
provider_type=self.provider_type,
enable_multipass=self.enable_multipass,
)
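
For illustration, a minimal usage sketch of the refactored chunker API; the model name is an assumed placeholder and `document` is assumed to be a danswer Document built elsewhere:

chunker = DefaultChunker(
    model_name="intfloat/e5-base-v2",  # assumed placeholder model name
    provider_type=None,                # None for a locally hosted model
    enable_multipass=False,
)
chunks = chunker.chunk(document=document)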

View File

@ -4,6 +4,7 @@ from typing import Protocol
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.constants import DEFAULT_BOOST
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@ -25,6 +26,7 @@ from danswer.indexing.chunker import DefaultChunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.search_settings import get_search_settings
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@ -181,11 +183,10 @@ def index_doc_batch(
)
logger.debug("Starting chunking")
# The embedder is needed here to get the correct tokenizer
chunks: list[DocAwareChunk] = [
chunk
for document in updatable_docs
for chunk in chunker.chunk(document=document, embedder=embedder)
for chunk in chunker.chunk(document=document)
]
logger.debug("Starting embedding")
@ -267,7 +268,17 @@ def build_indexing_pipeline(
ignore_time_skip: bool = False,
) -> IndexingPipelineProtocol:
"""Builds a pipline which takes in a list (batch) of docs and indexes them."""
chunker = chunker or DefaultChunker()
search_settings = get_search_settings()
multipass = (
search_settings.multipass_indexing
if search_settings
else ENABLE_MULTIPASS_INDEXING
)
chunker = chunker or DefaultChunker(
model_name=embedder.model_name,
provider_type=embedder.provider_type,
enable_multipass=multipass,
)
return partial(
index_doc_batch,

View File

@ -2,7 +2,6 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
@ -29,17 +28,19 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_multilingual_expansion
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
# Note: currently custom prompts do not allow datetime awareness, only default prompts do
multilingual_expansion = get_multilingual_expansion()
return (
check_number_of_tokens(prompt_config.system_prompt)
+ check_number_of_tokens(prompt_config.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
+ (LANGUAGE_HINT_TOKEN_CNT if multilingual_expansion else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
)
@ -135,7 +136,10 @@ def build_citations_user_message(
all_doc_useful: bool,
history_message: str = "",
) -> HumanMessage:
task_prompt_with_reminder = build_task_prompt_reminders(prompt_config)
multilingual_expansion = get_multilingual_expansion()
task_prompt_with_reminder = build_task_prompt_reminders(
prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)

View File

@ -2,7 +2,6 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
@ -12,6 +11,7 @@ from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_search_settings
def _build_weak_llm_quotes_prompt(
@ -19,7 +19,6 @@ def _build_weak_llm_quotes_prompt(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
use_language_hint: bool,
) -> HumanMessage:
"""Since Danswer supports a variety of LLMs, this less demanding prompt is provided
as an option to use with weaker LLMs such as small version, low float precision, quantized,
@ -48,8 +47,12 @@ def _build_strong_llm_quotes_prompt(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
use_language_hint: bool,
) -> HumanMessage:
search_settings = get_search_settings()
use_language_hint = (
bool(search_settings.multilingual_expansion) if search_settings else False
)
context_block = ""
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@ -79,7 +82,6 @@ def build_quotes_user_message(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
@ -92,7 +94,6 @@ def build_quotes_user_message(
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
use_language_hint=use_language_hint,
)
@ -101,7 +102,6 @@ def build_quotes_prompt(
context_docs: list[LlmDoc] | list[InferenceChunk],
history_str: str,
prompt: PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> HumanMessage:
prompt_builder = (
_build_weak_llm_quotes_prompt
@ -114,5 +114,4 @@ def build_quotes_prompt(
context_docs=context_docs,
history_str=history_str,
prompt=prompt,
use_language_hint=use_language_hint,
)

View File

@ -27,12 +27,14 @@ from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.constants import AuthType
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
@ -58,7 +60,10 @@ from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_settings import get_search_settings
from danswer.search.search_settings import update_search_settings
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@ -83,7 +88,7 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router
from danswer.server.manage.get_state import router as state_router
from danswer.server.manage.llm.api import admin_router as llm_admin_router
from danswer.server.manage.llm.api import basic_router as llm_router
from danswer.server.manage.secondary_index import router as secondary_index_router
from danswer.server.manage.search_settings import router as search_settings_router
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.standard_answer import router as standard_answer_router
from danswer.server.manage.users import router as user_router
@ -107,6 +112,9 @@ from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DEFAULT_CROSS_ENCODER_API_KEY
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
@ -243,11 +251,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
# fill up Postgres connection pools
await warm_up_connections()
@ -275,8 +278,37 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
search_settings = get_search_settings()
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.info("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.info(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
else:
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
logger.info("Reranking is enabled.")
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
raise ValueError("No reranking model specified.")
update_search_settings(
SavedSearchSettings(
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
disable_rerank_for_streaming=not ENABLE_RERANKING_REAL_TIME_FLOW,
num_rerank=NUM_POSTPROCESSED_RESULTS,
multilingual_expansion=[
s.strip()
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
if s.strip()
]
if MULTILINGUAL_QUERY_EXPANSION
else [],
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
)
)
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
@ -326,7 +358,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, secondary_index_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
)
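
The startup migration above folds the old env-based defaults into a single SavedSearchSettings record; in particular, the comma-separated MULTILINGUAL_QUERY_EXPANSION value is split into a clean list. A standalone sketch of that parsing, with an assumed raw value:

raw = "English, French ,Spanish,"  # assumed example env value
languages = [s.strip() for s in raw.split(",") if s.strip()]
assert languages == ["English", "French", "Spanish"]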

View File

@ -11,6 +11,7 @@ from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
@ -28,9 +29,9 @@ class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
# This is to forcibly skip (or run) the step; if None, the system defaults are used
skip_rerank: bool | None = None
rerank_settings: RerankingDetails | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
chain_of_thought: bool = False
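
A hedged sketch of the new request shape; `messages` is assumed to be a list[ThreadMessage] built elsewhere, and the model name matches the default cross-encoder defined in shared_configs later in this diff:

request = DirectQARequest(
    messages=messages,  # assumed: list[ThreadMessage] from the caller
    prompt_id=None,
    persona_id=0,
    rerank_settings=RerankingDetails(
        rerank_model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
        api_key=None,
        num_rerank=20,
    ),
)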

View File

@ -6,7 +6,6 @@ from langchain_core.messages import BaseMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import DocumentSource
from danswer.db.models import Prompt
from danswer.llm.answering.models import PromptConfig
@ -56,7 +55,7 @@ def add_date_time_to_prompt(prompt_str: str) -> str:
def build_task_prompt_reminders(
prompt: Prompt | PromptConfig,
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
use_language_hint: bool,
citation_str: str = CITATION_REMINDER,
language_hint_str: str = LANGUAGE_HINT,
) -> str:

View File

@ -6,7 +6,6 @@ from pydantic import validator
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.db.models import Persona
@ -14,7 +13,6 @@ from danswer.indexing.models import BaseChunk
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
MAX_METRICS_CONTENT = (
@ -22,6 +20,32 @@ MAX_METRICS_CONTENT = (
)
class RerankingDetails(BaseModel):
rerank_model_name: str
api_key: str | None
# Set to 0 to disable reranking explicitly
num_rerank: int
class SavedSearchSettings(RerankingDetails):
# Empty for no additional expansion
multilingual_expansion: list[str]
# Encompasses both mini and large chunks
multipass_indexing: bool
# For faster flows where results should start streaming back immediately,
# this more time-intensive step can be skipped
disable_rerank_for_streaming: bool
def to_reranking_detail(self) -> RerankingDetails:
return RerankingDetails(
rerank_model_name=self.rerank_model_name,
api_key=self.api_key,
num_rerank=self.num_rerank,
)
class Tag(BaseModel):
tag_key: str
tag_value: str
@ -60,8 +84,6 @@ class ChunkContext(BaseModel):
class SearchRequest(ChunkContext):
"""Input to the SearchPipeline."""
query: str
search_type: SearchType = SearchType.SEMANTIC
@ -74,10 +96,10 @@ class SearchRequest(ChunkContext):
offset: int | None = None
limit: int | None = None
multilingual_expansion: list[str] | None = None
recency_bias_multiplier: float = 1.0
hybrid_alpha: float | None = None
# This is to forcibly skip (or run) the step; if None, the system defaults are used
skip_rerank: bool | None = None
rerank_settings: RerankingDetails | None = None
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
class Config:
@ -85,24 +107,23 @@ class SearchRequest(ChunkContext):
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
query: str
processed_keywords: list[str]
search_type: SearchType
evaluation_type: LLMEvaluationType
filters: IndexFilters
rerank_settings: RerankingDetails | None
hybrid_alpha: float
recency_bias_multiplier: float
# Only used if the LLM evaluation type is not skip; None to use default settings
max_llm_filter_sections: int
num_hits: int = NUM_RETURNED_HITS
offset: int = 0
skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
# Only used if not skip_llm_chunk_filter
max_llm_filter_sections: int = NUM_RERANKED_RESULTS
class Config:
frozen = True
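
A short sketch of the new settings model and its conversion helper; the field values are illustrative:

settings = SavedSearchSettings(
    rerank_model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
    api_key=None,
    num_rerank=20,
    multilingual_expansion=["English", "French"],
    multipass_indexing=False,
    disable_rerank_for_streaming=True,
)
rerank_only = settings.to_reranking_detail()  # drops the indexing/streaming fields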

View File

@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
@ -150,7 +149,6 @@ class SearchPipeline:
query=self.search_query,
document_index=self.document_index,
db_session=self.db_session,
multilingual_expansion_str=MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback=self.retrieval_metrics_callback,
)

View File

@ -22,13 +22,11 @@ from danswer.search.models import InferenceSection
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
logger = setup_logger()
@ -44,11 +42,6 @@ def _log_top_section_links(search_flow: str, sections: list[InferenceSection]) -
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
def should_rerank(query: SearchQuery) -> bool:
# Don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
def _remove_title(chunk: InferenceChunkUncleaned) -> str:
if not chunk.title or not chunk.content:
@ -84,7 +77,7 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
@log_function_time(print_only=True)
def semantic_reranking(
query: str,
query: SearchQuery,
chunks: list[InferenceChunk],
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
@ -95,17 +88,24 @@ def semantic_reranking(
Note: this updates the chunks in place; it overwrites the chunk scores that came from retrieval
"""
# TODO update this
rerank_settings = query.rerank_settings
if not rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
cross_encoder = RerankingModel(
model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
api_key=None,
model_name=rerank_settings.rerank_model_name,
api_key=rerank_settings.api_key,
)
passages = [
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks
for chunk in chunks_to_rerank
]
sim_scores_floats = cross_encoder.predict(query=query, passages=passages)
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
# Old logic to handle multiple cross-encoders preserved but not used
sim_scores = [numpy.array(sim_scores_floats)]
@ -118,15 +118,17 @@ def semantic_reranking(
[enc_n_scores - cross_models_min for enc_n_scores in sim_scores]
) / len(sim_scores)
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks]
boosts = [
translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks_to_rerank
]
recency_multiplier = [chunk.recency_bias for chunk in chunks_to_rerank]
boosted_sim_scores = shifted_sim_scores * boosts * recency_multiplier
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min
)
orig_indices = [i for i in range(len(normalized_b_s_scores))]
scored_results = list(
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
zip(normalized_b_s_scores, raw_sim_scores, chunks_to_rerank, orig_indices)
)
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
@ -177,12 +179,16 @@ def rerank_sections(
"""
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
if not query.rerank_settings:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
ranked_chunks, _ = semantic_reranking(
query=query.query,
chunks=chunks_to_rerank[: query.num_rerank],
query=query,
chunks=chunks_to_rerank,
rerank_metrics_callback=rerank_metrics_callback,
)
lower_chunks = chunks_to_rerank[query.num_rerank :]
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
# However the ordering is still important
@ -252,7 +258,7 @@ def search_postprocessing(
rerank_task_id = None
sections_yielded = False
if should_rerank(search_query):
if search_query.rerank_settings:
post_processing_tasks.append(
FunctionCall(
rerank_sections,
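
The rescoring above shifts the raw cross-encoder scores, applies boost and recency multipliers, then renormalizes into the model's score range. A self-contained numeric sketch of the same arithmetic, with assumed scores and range:

import numpy

model_min, model_max = -12, 12                 # assumed CROSS_ENCODER_RANGE_* values
raw_sim_scores = numpy.array([3.1, -0.4, 5.0]) # assumed cross-encoder outputs
boosts = numpy.array([1.0, 1.2, 1.0])          # from chunk boost counts
recency = numpy.array([1.0, 1.0, 0.8])         # per-chunk recency bias

cross_models_min = raw_sim_scores.min()
shifted = raw_sim_scores - cross_models_min    # shift scores to start at 0
boosted = shifted * boosts * recency
normalized = (boosted + cross_models_min - model_min) / (model_max - model_min)
ranking = normalized.argsort()[::-1]           # highest normalized score first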

View File

@ -2,10 +2,10 @@ from sqlalchemy.orm import Session
from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.configs.chat_configs import EDIT_KEYWORD_QUERY
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
from danswer.llm.interfaces import LLM
@ -19,13 +19,14 @@ from danswer.search.models import SearchRequest
from danswer.search.models import SearchType
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_settings import get_search_settings
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
logger = setup_logger()
@ -137,7 +138,8 @@ def retrieval_preprocessing(
all_query_terms = query.split()
processed_keywords = (
remove_stop_words_and_punctuation(all_query_terms)
if EDIT_KEYWORD_QUERY
# If the user is using a different language, don't edit the query or remove English stopwords
if not search_request.multilingual_expansion
else all_query_terms
)
@ -170,9 +172,15 @@ def retrieval_preprocessing(
)
llm_evaluation_type = LLMEvaluationType.SKIP
skip_rerank = search_request.skip_rerank
if skip_rerank is None:
skip_rerank = not ENABLE_RERANKING_REAL_TIME_FLOW
rerank_settings = search_request.rerank_settings
# If not explicitly specified by the query, use the current settings
if rerank_settings is None:
saved_search_settings = get_search_settings()
if not saved_search_settings:
rerank_settings = None
# For non-streaming flows, the rerank settings are applied at the search_request level
elif not saved_search_settings.disable_rerank_for_streaming:
rerank_settings = saved_search_settings.to_reranking_detail()
# Decays at 1 / (1 + (multiplier * num years))
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
@ -201,7 +209,13 @@ def retrieval_preprocessing(
recency_bias_multiplier=recency_bias_multiplier,
num_hits=limit if limit is not None else NUM_RETURNED_HITS,
offset=offset or 0,
skip_rerank=skip_rerank,
rerank_settings=rerank_settings,
# The LLM filtering should match the reranked set; this is understood as the number of results
# the user wants heavier processing on, so apply the same count to the LLM if reranking is on.
# If no reranking settings are set, use the global default
max_llm_filter_sections=rerank_settings.num_rerank
if rerank_settings
else NUM_POSTPROCESSED_RESULTS,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,
full_doc=search_request.full_doc,
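
For reference, the recency decay mentioned in the comment above follows 1 / (1 + (multiplier * num years)); a tiny sketch with an assumed multiplier:

def recency_decay(age_years: float, multiplier: float = 0.5) -> float:
    # Matches the comment above; the 0.5 default is an assumed example value.
    return 1 / (1 + multiplier * age_years)

assert recency_decay(0.0) == 1.0  # new documents keep full weight
assert recency_decay(2.0) == 0.5  # at 0.5x, two years halves the weight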

View File

@ -7,7 +7,6 @@ from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@ -19,6 +18,7 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_settings import get_multilingual_expansion
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
@ -150,13 +150,14 @@ def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
db_session: Session,
multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion()
# Don't do query expansion on complex queries; rephrasings likely would not work well
if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query:
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, db_session=db_session
)
@ -166,7 +167,7 @@ def retrieve_chunks(
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion_str
query.query, multilingual_expansion
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)

View File

@ -0,0 +1,40 @@
from typing import cast
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.search.models import SavedSearchSettings
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_multilingual_expansion() -> list[str]:
search_settings = get_search_settings()
return search_settings.multilingual_expansion if search_settings else []
def get_search_settings() -> SavedSearchSettings | None:
"""Get all user configured search settings which affect the search pipeline
Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values
Note: for now we can't cache this value because if the API server is scaled, the cache could be out of sync
if the value is updated by another process/instance of the API server. If this reads from an in memory cache like
reddis then it will be ok. Until then this has some performance implications (though minor)
"""
kv_store = get_dynamic_config_store()
try:
return SavedSearchSettings(**cast(dict, kv_store.load(KV_SEARCH_SETTINGS)))
except ConfigNotFoundError:
return None
except Exception as e:
logger.error(f"Error loading search settings: {e}")
# Wipe it so that on the next server startup it can load the defaults,
# or the user can set it via the API/UI
kv_store.delete(KV_SEARCH_SETTINGS)
return None
def update_search_settings(settings: SavedSearchSettings) -> None:
kv_store = get_dynamic_config_store()
kv_store.store(KV_SEARCH_SETTINGS, settings.dict())
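
Given this new module, a round trip through the KV store looks like the following; field values are illustrative:

update_search_settings(
    SavedSearchSettings(
        rerank_model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
        api_key=None,
        num_rerank=20,
        multilingual_expansion=[],
        multipass_indexing=True,
        disable_rerank_for_streaming=False,
    )
)
loaded = get_search_settings()
assert loaded is not None and loaded.multipass_indexing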

View File

@ -1,12 +1,12 @@
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import LANGUAGE_CHAT_NAMING_HINT
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import CHAT_NAMING
from danswer.search.search_settings import get_multilingual_expansion
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -22,7 +22,7 @@ def get_renamed_conversation_name(
language_hint = (
f"\n{LANGUAGE_CHAT_NAMING_HINT.strip()}"
if bool(MULTILINGUAL_QUERY_EXPANSION)
if bool(get_multilingual_expansion())
else ""
)

View File

@ -50,11 +50,10 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
def multilingual_query_expansion(
query: str,
expansion_languages: str,
expansion_languages: list[str],
use_threads: bool = True,
) -> list[str]:
languages = expansion_languages.split(",")
languages = [language.strip() for language in languages]
languages = [language.strip() for language in expansion_languages]
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_multilingual_query_expansion, (query, language))
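
A hypothetical call with the new list-based signature; the query and languages are example values:

rephrases = multilingual_query_expansion(
    query="Where is the billing dashboard?",  # assumed example query
    expansion_languages=["English", "French"],
    use_threads=True,
)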

View File

@ -20,11 +20,14 @@ from danswer.db.models import IndexModelStatus
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.models import SavedSearchSettings
from danswer.search.search_settings import get_search_settings
from danswer.search.search_settings import update_search_settings
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
router = APIRouter(prefix="/secondary-index")
router = APIRouter(prefix="/search-settings")
logger = setup_logger()
@ -158,3 +161,18 @@ def get_embedding_models(
if next_model
else None,
)
@router.get("/get-search-settings")
def get_saved_search_settings(
_: User | None = Depends(current_admin_user),
) -> SavedSearchSettings | None:
return get_search_settings()
@router.post("/update-search-settings")
def update_saved_search_settings(
search_settings: SavedSearchSettings,
_: User | None = Depends(current_admin_user),
) -> None:
update_search_settings(search_settings)
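
A hedged example of calling the renamed endpoints with the requests library; the host, API prefix, and admin authentication are assumptions for illustration:

import requests

BASE = "http://localhost:3000/api"  # assumed deployment URL and proxy prefix
session = requests.Session()        # assumed to carry an admin auth cookie

current = session.get(f"{BASE}/search-settings/get-search-settings").json()
session.post(
    f"{BASE}/search-settings/update-search-settings",
    json={
        "rerank_model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
        "api_key": None,
        "num_rerank": 20,
        "multilingual_expansion": [],
        "multipass_indexing": False,
        "disable_rerank_for_streaming": True,
    },
)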

View File

@ -4,6 +4,7 @@ from danswer.configs.constants import DocumentSource
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.server.manage.models import StandardAnswer
@ -23,8 +24,8 @@ class DocumentSearchRequest(ChunkContext):
retrieval_options: RetrievalDetails
recency_bias_multiplier: float = 1.0
evaluation_type: LLMEvaluationType
# This is to forcibly skip (or run) the step, if None it uses the system defaults
skip_rerank: bool | None = None
# None to use system defaults for reranking
rerank_settings: RerankingDetails | None = None
class BasicCreateChatMessageRequest(ChunkContext):

View File

@ -63,7 +63,7 @@ def handle_search_request(
persona=None, # For simplicity, default settings should be good for this search
offset=search_request.retrieval_options.offset,
limit=search_request.retrieval_options.limit,
skip_rerank=search_request.skip_rerank,
rerank_settings=search_request.rerank_settings,
evaluation_type=search_request.evaluation_type,
chunks_above=search_request.chunks_above,
chunks_below=search_request.chunks_below,

View File

@ -28,7 +28,9 @@ ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Used for loading defaults for automatic deployments and dev flows
DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
DEFAULT_CROSS_ENCODER_API_KEY = os.environ.get("DEFAULT_CROSS_ENCODER_API_KEY")
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.

View File

@ -37,7 +37,12 @@ def test_chunk_document() -> None:
passage_prefix=None,
)
chunks = chunk_document(document, embedder=embedder)
chunks = chunk_document(
document=document,
model_name=embedder.model_name,
provider_type=embedder.provider_type,
enable_multipass=False,
)
assert len(chunks) == 5
assert short_section_1 in chunks[0].content
assert short_section_3 in chunks[-1].content

View File

@ -35,7 +35,7 @@ ENABLE_RERANKING_REAL_TIME_FLOW="False"
# At the cost of indexing speed (~5x slower); query time is unaffected
# Since reranking is turned off and multilingual retrieval is generally harder,
# it is advised to turn this on
ENABLE_MINI_CHUNK="True"
ENABLE_MULTIPASS_INDEXING="True"
# Using a stronger LLM will help with multilingual tasks
# Since documents may be in multiple languages, and there are additional instructions to respond

View File

@ -46,6 +46,6 @@ spec:
- configMapRef:
name: {{ .Values.config.envConfigMapName }}
env:
- name: ENABLE_MINI_CHUNK
- name: ENABLE_MULTIPASS_INDEXING
value: "{{ .Values.background.enableMiniChunk }}"
{{- include "danswer-stack.envSecrets" . | nindent 12}}

View File

@ -82,7 +82,7 @@ function Main() {
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/secondary-index/get-current-embedding-model",
"/api/search-settings/get-current-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@ -97,7 +97,7 @@ function Main() {
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/secondary-index/get-secondary-embedding-model",
"/api/search-settings/get-secondary-embedding-model",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@ -139,7 +139,7 @@ function Main() {
}
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
"/api/search-settings/set-new-embedding-model",
{
method: "POST",
body: JSON.stringify(newModel),
@ -151,7 +151,7 @@ function Main() {
if (response.ok) {
setShowTentativeOpenProvider(null);
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-embedding-model");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}
@ -161,12 +161,12 @@ function Main() {
};
const onCancel = async () => {
const response = await fetch("/api/secondary-index/cancel-new-embedding", {
const response = await fetch("/api/search-settings/cancel-new-embedding", {
method: "POST",
});
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-embedding-model");
} else {
alert(
`Failed to cancel embedding model update - ${await response.text()}`
@ -189,7 +189,7 @@ function Main() {
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/secondary-index/set-new-embedding-model",
"/api/search-settings/set-new-embedding-model",
{
method: "POST",
body: JSON.stringify(model),
@ -200,7 +200,7 @@ function Main() {
);
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/secondary-index/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-embedding-model");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}

View File

@ -51,7 +51,7 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchAssistantsSS(),
fetchSS("/query/valid-tags"),
fetchSS("/secondary-index/get-embedding-models"),
fetchSS("/search-settings/get-embedding-models"),
fetchSS("/query/user-searches"),
];