Turn off Reranking for Streaming Flows (#770)

This commit is contained in:
Yuhong Sun 2023-11-26 16:45:23 -08:00 committed by GitHub
parent 2665bff78e
commit d291fea020
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 36 additions and 19 deletions

View File

@ -339,6 +339,6 @@ if __name__ == "__main__":
if not MODEL_SERVER_HOST: if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)") logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True) warm_up_models(indexer_only=True, skip_cross_encoders=True)
logger.info("Starting Indexing Loop") logger.info("Starting Indexing Loop")
update_loop() update_loop()

View File

@ -34,7 +34,12 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
# Cross Encoder Settings # Cross Encoder Settings
# This following setting is for non-real-time-flows
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true" SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
# This one is for real-time (streaming) flows
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
CROSS_ENCODER_MODEL_ENSEMBLE = [ CROSS_ENCODER_MODEL_ENSEMBLE = [
"cross-encoder/ms-marco-MiniLM-L-4-v2", "cross-encoder/ms-marco-MiniLM-L-4-v2",

View File

@ -180,7 +180,6 @@ def handle_message(
user=None, user=None,
db_session=db_session, db_session=db_session,
answer_generation_timeout=answer_generation_timeout, answer_generation_timeout=answer_generation_timeout,
real_time_flow=False,
enable_reflexion=reflexion, enable_reflexion=reflexion,
bypass_acl=bypass_acl, bypass_acl=bypass_acl,
) )
@ -205,6 +204,7 @@ def handle_message(
query=msg, query=msg,
filters=filters, filters=filters,
enable_auto_detect_filters=not disable_auto_detect_filters, enable_auto_detect_filters=not disable_auto_detect_filters,
real_time=False,
) )
) )
except Exception as e: except Exception as e:

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@ -295,7 +296,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
# without issue. # without issue.
if __name__ == "__main__": if __name__ == "__main__":
try: try:
warm_up_models() warm_up_models(skip_cross_encoders=SKIP_RERANKING)
socket_client = _get_socket_client() socket_client = _get_socket_client()
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore

View File

@ -48,7 +48,6 @@ def answer_qa_query(
db_session: Session, db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI, disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
answer_generation_timeout: int = QA_TIMEOUT, answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
enable_reflexion: bool = False, enable_reflexion: bool = False,
bypass_acl: bool = False, bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@ -118,7 +117,7 @@ def answer_qa_query(
try: try:
qa_model = get_default_qa_model( qa_model = get_default_qa_model(
timeout=answer_generation_timeout, real_time_flow=real_time_flow timeout=answer_generation_timeout, real_time_flow=question.real_time
) )
except Exception as e: except Exception as e:
return partial_response( return partial_response(
@ -159,7 +158,7 @@ def answer_qa_query(
) )
validity = None validity = None
if not real_time_flow and enable_reflexion and d_answer is not None: if not question.real_time and enable_reflexion and d_answer is not None:
validity = False validity = False
if d_answer.answer is not None: if d_answer.answer is not None:
validity = get_answer_validity(query, d_answer.answer) validity = get_answer_validity(query, d_answer.answer)

View File

@ -30,11 +30,11 @@ from danswer.configs.constants import AuthType
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
@ -186,8 +186,8 @@ def get_application() -> FastAPI:
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
) )
if SKIP_RERANKING: if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is disabled") logger.info("Reranking step of search flow is enabled.")
logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"') logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX: if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX:
@ -200,7 +200,7 @@ def get_application() -> FastAPI:
) )
else: else:
logger.info("Warming up local NLP models.") logger.info("Warming up local NLP models.")
warm_up_models() warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW)
if torch.cuda.is_available(): if torch.cuda.is_available():
logger.info("GPU is available") logger.info("GPU is available")

View File

@ -7,7 +7,7 @@ from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import SKIP_RERANKING from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk from danswer.indexing.models import IndexChunk
@ -55,7 +55,7 @@ class SearchQuery(BaseModel):
filters: IndexFilters filters: IndexFilters
favor_recent: bool favor_recent: bool
num_hits: int = NUM_RETURNED_HITS num_hits: int = NUM_RETURNED_HITS
skip_rerank: bool = SKIP_RERANKING skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW
# Only used if not skip_rerank # Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS num_rerank: int | None = NUM_RERANKED_RESULTS
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER

View File

@ -21,7 +21,6 @@ from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from shared_models.model_server_models import EmbedRequest from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse from shared_models.model_server_models import EmbedResponse
@ -294,7 +293,8 @@ class IntentModel:
def warm_up_models( def warm_up_models(
indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING skip_cross_encoders: bool = False,
indexer_only: bool = False,
) -> None: ) -> None:
warm_up_str = ( warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at " "Danswer is amazing! Check out our easy deployment guide at "

View File

@ -16,8 +16,10 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.feedback import create_query_event from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User from danswer.db.models import User
@ -392,7 +394,7 @@ def retrieve_chunks(
def should_rerank(query: SearchQuery) -> bool: def should_rerank(query: SearchQuery) -> bool:
# don't re-rank for keyword search # Don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank return query.search_type != SearchType.KEYWORD and not query.skip_rerank
@ -556,6 +558,8 @@ def danswer_search_generator(
db_session: Session, db_session: Session,
document_index: DocumentIndex, document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
skip_rerank_non_realtime: bool = SKIP_RERANKING,
bypass_acl: bool = False, bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
@ -563,7 +567,7 @@ def danswer_search_generator(
) -> Iterator[list[InferenceChunk] | list[bool] | int]: ) -> Iterator[list[InferenceChunk] | list[bool] | int]:
"""The main entry point for search. This fetches the relevant documents from Vespa """The main entry point for search. This fetches the relevant documents from Vespa
based on the provided query (applying permissions / filters), does any specified based on the provided query (applying permissions / filters), does any specified
post-processing, and returns the results. It also create an entry in the query_event table post-processing, and returns the results. It also creates an entry in the query_event table
for this search event.""" for this search event."""
query_event_id = create_query_event( query_event_id = create_query_event(
query=question.query, query=question.query,
@ -583,6 +587,10 @@ def danswer_search_generator(
access_control_list=user_acl_filters, access_control_list=user_acl_filters,
) )
skip_reranking = (
skip_rerank_realtime if question.real_time else skip_rerank_non_realtime
)
search_query = SearchQuery( search_query = SearchQuery(
query=question.query, query=question.query,
search_type=question.search_type, search_type=question.search_type,
@ -591,6 +599,7 @@ def danswer_search_generator(
favor_recent=question.favor_recent favor_recent=question.favor_recent
if question.favor_recent is not None if question.favor_recent is not None
else False, else False,
skip_rerank=skip_reranking,
skip_llm_chunk_filter=skip_llm_chunk_filter, skip_llm_chunk_filter=skip_llm_chunk_filter,
) )

View File

@ -179,6 +179,9 @@ class QuestionRequest(BaseModel):
search_type: SearchType = SearchType.HYBRID search_type: SearchType = SearchType.HYBRID
enable_auto_detect_filters: bool = True enable_auto_detect_filters: bool = True
favor_recent: bool | None = None favor_recent: bool | None = None
# Is this a real-time/streaming call or a question where Danswer can take more time?
real_time: bool = True
# Pagination purposes, offset is in batches, not by document count
offset: int | None = None offset: int | None = None

View File

@ -84,6 +84,7 @@ def get_answer_for_question(
question = QuestionRequest( question = QuestionRequest(
query=query, query=query,
filters=filters, filters=filters,
real_time=False,
enable_auto_detect_filters=False, enable_auto_detect_filters=False,
) )
@ -96,7 +97,6 @@ def get_answer_for_question(
user=None, user=None,
db_session=db_session, db_session=db_session,
answer_generation_timeout=100, answer_generation_timeout=100,
real_time_flow=False,
enable_reflexion=False, enable_reflexion=False,
bypass_acl=True, bypass_acl=True,
retrieval_metrics_callback=retrieval_metrics.record_metric, retrieval_metrics_callback=retrieval_metrics.record_metric,

View File

@ -26,7 +26,7 @@ engine = get_sqlalchemy_engine()
def redirect_print_to_file(file: TextIO) -> Any: def redirect_print_to_file(file: TextIO) -> Any:
original_print = builtins.print original_print = builtins.print
def new_print(*args, **kwargs): def new_print(*args: Any, **kwargs: Any) -> Any:
kwargs["file"] = file kwargs["file"] = file
original_print(*args, **kwargs) original_print(*args, **kwargs)

View File

@ -45,7 +45,7 @@ services:
- SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-} - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- SKIP_RERANKING=${SKIP_RERANKING:-} - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-} - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-} - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-} - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}