mirror of https://github.com/danswer-ai/danswer.git
Turn off Reranking for Streaming Flows (#770)
This commit is contained in:
parent 2665bff78e
commit d291fea020
@@ -339,6 +339,6 @@ if __name__ == "__main__":
     if not MODEL_SERVER_HOST:
         logger.info("Warming up Embedding Model(s)")
-        warm_up_models(indexer_only=True)
+        warm_up_models(indexer_only=True, skip_cross_encoders=True)
     logger.info("Starting Indexing Loop")
     update_loop()
@@ -34,7 +34,12 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)


 # Cross Encoder Settings
+# This following setting is for non-real-time-flows
 SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
+# This one is for real-time (streaming) flows
+ENABLE_RERANKING_REAL_TIME_FLOW = (
+    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
+)
 # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
 CROSS_ENCODER_MODEL_ENSEMBLE = [
     "cross-encoder/ms-marco-MiniLM-L-4-v2",
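For reference, a minimal sketch of how the two flags above resolve at import time (illustrative only; the environment value set here is an assumption for the example, not part of this commit). Only the exact string "true", case-insensitive, enables a flag, and each flag defaults to False when its variable is unset.

import os

# Illustrative environment, not from the commit: opt in to reranking for streaming queries.
os.environ["ENABLE_RERANKING_REAL_TIME_FLOW"] = "true"

# Same parsing as in danswer.configs.model_configs above.
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
ENABLE_RERANKING_REAL_TIME_FLOW = (
    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)

print(SKIP_RERANKING)                   # False: variable unset
print(ENABLE_RERANKING_REAL_TIME_FLOW)  # True: explicitly opted in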
@@ -180,7 +180,6 @@ def handle_message(
         user=None,
         db_session=db_session,
         answer_generation_timeout=answer_generation_timeout,
-        real_time_flow=False,
         enable_reflexion=reflexion,
         bypass_acl=bypass_acl,
     )
@@ -205,6 +204,7 @@ def handle_message(
                 query=msg,
                 filters=filters,
                 enable_auto_detect_filters=not disable_auto_detect_filters,
+                real_time=False,
             )
         )
     except Exception as e:
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session

 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
+from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
 from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@@ -295,7 +296,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
 # without issue.
 if __name__ == "__main__":
     try:
-        warm_up_models()
+        warm_up_models(skip_cross_encoders=SKIP_RERANKING)

         socket_client = _get_socket_client()
         socket_client.socket_mode_request_listeners.append(process_slack_event)  # type: ignore
@@ -48,7 +48,6 @@ def answer_qa_query(
     db_session: Session,
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
-    real_time_flow: bool = True,
     enable_reflexion: bool = False,
     bypass_acl: bool = False,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -118,7 +117,7 @@ def answer_qa_query(

     try:
         qa_model = get_default_qa_model(
-            timeout=answer_generation_timeout, real_time_flow=real_time_flow
+            timeout=answer_generation_timeout, real_time_flow=question.real_time
         )
     except Exception as e:
         return partial_response(
@@ -159,7 +158,7 @@ def answer_qa_query(
     )

     validity = None
-    if not real_time_flow and enable_reflexion and d_answer is not None:
+    if not question.real_time and enable_reflexion and d_answer is not None:
         validity = False
         if d_answer.answer is not None:
             validity = get_answer_validity(query, d_answer.answer)
@@ -30,11 +30,11 @@ from danswer.configs.constants import AuthType
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
-from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.db.credentials import create_initial_public_credential
 from danswer.direct_qa.factory import get_default_qa_model
 from danswer.document_index.factory import get_default_document_index
@@ -186,8 +186,8 @@ def get_application() -> FastAPI:
             f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
         )

-    if SKIP_RERANKING:
-        logger.info("Reranking step of search flow is disabled")
+    if ENABLE_RERANKING_REAL_TIME_FLOW:
+        logger.info("Reranking step of search flow is enabled.")

     logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
     if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX:
@@ -200,7 +200,7 @@ def get_application() -> FastAPI:
         )
     else:
         logger.info("Warming up local NLP models.")
-        warm_up_models()
+        warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW)

         if torch.cuda.is_available():
             logger.info("GPU is available")
@@ -7,7 +7,7 @@ from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
-from danswer.configs.model_configs import SKIP_RERANKING
+from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk

@@ -55,7 +55,7 @@ class SearchQuery(BaseModel):
     filters: IndexFilters
     favor_recent: bool
     num_hits: int = NUM_RETURNED_HITS
-    skip_rerank: bool = SKIP_RERANKING
+    skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW
     # Only used if not skip_rerank
     num_rerank: int | None = NUM_RERANKED_RESULTS
     skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
@@ -21,7 +21,6 @@ from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
 from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
-from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.utils.logger import setup_logger
 from shared_models.model_server_models import EmbedRequest
 from shared_models.model_server_models import EmbedResponse
@@ -294,7 +293,8 @@ class IntentModel:


 def warm_up_models(
-    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+    skip_cross_encoders: bool = False,
+    indexer_only: bool = False,
 ) -> None:
     warm_up_str = (
         "Danswer is amazing! Check out our easy deployment guide at "
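With the SKIP_RERANKING default removed from warm_up_models, every caller now decides explicitly whether to load the cross-encoders. A self-contained sketch of the three call sites this commit updates (the stubbed warm_up_models and the hard-coded flag values below are placeholders for illustration, not the real implementations):

def warm_up_models(skip_cross_encoders: bool = False, indexer_only: bool = False) -> None:
    # Stub: the real function loads the local embedding / cross-encoder / intent models.
    print(f"skip_cross_encoders={skip_cross_encoders}, indexer_only={indexer_only}")

SKIP_RERANKING = False                   # stand-ins for the config flags above
ENABLE_RERANKING_REAL_TIME_FLOW = False

warm_up_models(indexer_only=True, skip_cross_encoders=True)               # indexing loop: never reranks
warm_up_models(skip_cross_encoders=SKIP_RERANKING)                        # Slack bot: non-real-time flow
warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW)   # API server: real-time flow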
@@ -16,8 +16,10 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
 from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
+from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
 from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
+from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.db.feedback import create_query_event
 from danswer.db.feedback import update_query_event_retrieved_documents
 from danswer.db.models import User
@@ -392,7 +394,7 @@ def retrieve_chunks(


 def should_rerank(query: SearchQuery) -> bool:
-    # don't re-rank for keyword search
+    # Don't re-rank for keyword search
     return query.search_type != SearchType.KEYWORD and not query.skip_rerank

@@ -556,6 +558,8 @@ def danswer_search_generator(
     db_session: Session,
     document_index: DocumentIndex,
     skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
+    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
+    skip_rerank_non_realtime: bool = SKIP_RERANKING,
     bypass_acl: bool = False,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
@@ -563,7 +567,7 @@ def danswer_search_generator(
 ) -> Iterator[list[InferenceChunk] | list[bool] | int]:
     """The main entry point for search. This fetches the relevant documents from Vespa
     based on the provided query (applying permissions / filters), does any specified
-    post-processing, and returns the results. It also create an entry in the query_event table
+    post-processing, and returns the results. It also creates an entry in the query_event table
     for this search event."""
     query_event_id = create_query_event(
         query=question.query,
@@ -583,6 +587,10 @@ def danswer_search_generator(
         access_control_list=user_acl_filters,
     )

+    skip_reranking = (
+        skip_rerank_realtime if question.real_time else skip_rerank_non_realtime
+    )
+
     search_query = SearchQuery(
         query=question.query,
         search_type=question.search_type,
@@ -591,6 +599,7 @@ def danswer_search_generator(
         favor_recent=question.favor_recent
         if question.favor_recent is not None
         else False,
+        skip_rerank=skip_reranking,
         skip_llm_chunk_filter=skip_llm_chunk_filter,
     )

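Taken together, the two new parameters mean the rerank decision now hinges on whether the request is a real-time (streaming) one. A minimal sketch of that selection under the defaults above (the helper function and its parameter defaults are hypothetical, added only for illustration):

def resolve_skip_rerank(
    real_time: bool,
    enable_reranking_real_time_flow: bool = False,  # default of ENABLE_RERANKING_REAL_TIME_FLOW
    skip_reranking_config: bool = False,            # default of SKIP_RERANKING
) -> bool:
    # Mirrors the selection added in the hunk above.
    skip_rerank_realtime = not enable_reranking_real_time_flow
    skip_rerank_non_realtime = skip_reranking_config
    return skip_rerank_realtime if real_time else skip_rerank_non_realtime

print(resolve_skip_rerank(real_time=True))   # True: streaming flows skip reranking by default
print(resolve_skip_rerank(real_time=False))  # False: background flows still rerank by default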
@@ -179,6 +179,9 @@ class QuestionRequest(BaseModel):
     search_type: SearchType = SearchType.HYBRID
     enable_auto_detect_filters: bool = True
     favor_recent: bool | None = None
+    # Is this a real-time/streaming call or a question where Danswer can take more time?
+    real_time: bool = True
+    # Pagination purposes, offset is in batches, not by document count
     offset: int | None = None

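The new real_time field defaults to True, so streaming callers need no change, while background callers (the Slack handler and the regression script below) opt out explicitly. A self-contained pydantic sketch of just this behaviour (the trimmed-down model is illustrative, not the full QuestionRequest):

from pydantic import BaseModel

class QuestionRequestSketch(BaseModel):
    # Trimmed-down stand-in for QuestionRequest, keeping only the new flag.
    query: str
    # Is this a real-time/streaming call or a question where Danswer can take more time?
    real_time: bool = True

streaming = QuestionRequestSketch(query="example question")
background = QuestionRequestSketch(query="example question", real_time=False)
print(streaming.real_time, background.real_time)  # True False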
@@ -84,6 +84,7 @@ def get_answer_for_question(
     question = QuestionRequest(
         query=query,
         filters=filters,
+        real_time=False,
         enable_auto_detect_filters=False,
     )

@@ -96,7 +97,6 @@ def get_answer_for_question(
         user=None,
         db_session=db_session,
         answer_generation_timeout=100,
-        real_time_flow=False,
         enable_reflexion=False,
         bypass_acl=True,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
@@ -26,7 +26,7 @@ engine = get_sqlalchemy_engine()
 def redirect_print_to_file(file: TextIO) -> Any:
     original_print = builtins.print

-    def new_print(*args, **kwargs):
+    def new_print(*args: Any, **kwargs: Any) -> Any:
         kwargs["file"] = file
         original_print(*args, **kwargs)

@@ -45,7 +45,7 @@ services:
      - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
      - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
-      - SKIP_RERANKING=${SKIP_RERANKING:-}
+      - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
      - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
      - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
      - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}