Turn off Reranking for Streaming Flows (#770)

This commit is contained in:
Yuhong Sun 2023-11-26 16:45:23 -08:00 committed by GitHub
parent 2665bff78e
commit d291fea020
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 36 additions and 19 deletions

View File

@ -339,6 +339,6 @@ if __name__ == "__main__":
if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True)
warm_up_models(indexer_only=True, skip_cross_encoders=True)
logger.info("Starting Indexing Loop")
update_loop()

View File

@ -34,7 +34,12 @@ MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
# Cross Encoder Settings
# The following setting is for non-real-time flows
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
# This one is for real-time (streaming) flows
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
CROSS_ENCODER_MODEL_ENSEMBLE = [
"cross-encoder/ms-marco-MiniLM-L-4-v2",

View File

@ -180,7 +180,6 @@ def handle_message(
user=None,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
real_time_flow=False,
enable_reflexion=reflexion,
bypass_acl=bypass_acl,
)
@ -205,6 +204,7 @@ def handle_message(
query=msg,
filters=filters,
enable_auto_detect_filters=not disable_auto_detect_filters,
real_time=False,
)
)
except Exception as e:

View File

@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import SLACK_CHANNEL_ID
from danswer.danswerbot.slack.handlers.handle_feedback import handle_slack_feedback
@ -295,7 +296,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
# without issue.
if __name__ == "__main__":
try:
warm_up_models()
warm_up_models(skip_cross_encoders=SKIP_RERANKING)
socket_client = _get_socket_client()
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore

View File

@ -48,7 +48,6 @@ def answer_qa_query(
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
enable_reflexion: bool = False,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@ -118,7 +117,7 @@ def answer_qa_query(
try:
qa_model = get_default_qa_model(
timeout=answer_generation_timeout, real_time_flow=real_time_flow
timeout=answer_generation_timeout, real_time_flow=question.real_time
)
except Exception as e:
return partial_response(
@ -159,7 +158,7 @@ def answer_qa_query(
)
validity = None
if not real_time_flow and enable_reflexion and d_answer is not None:
if not question.real_time and enable_reflexion and d_answer is not None:
validity = False
if d_answer.answer is not None:
validity = get_answer_validity(query, d_answer.answer)

View File

@ -30,11 +30,11 @@ from danswer.configs.constants import AuthType
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index
@ -186,8 +186,8 @@ def get_application() -> FastAPI:
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX:
@ -200,7 +200,7 @@ def get_application() -> FastAPI:
)
else:
logger.info("Warming up local NLP models.")
warm_up_models()
warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW)
if torch.cuda.is_available():
logger.info("GPU is available")

View File

@ -7,7 +7,7 @@ from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@ -55,7 +55,7 @@ class SearchQuery(BaseModel):
filters: IndexFilters
favor_recent: bool
num_hits: int = NUM_RETURNED_HITS
skip_rerank: bool = SKIP_RERANKING
skip_rerank: bool = not ENABLE_RERANKING_REAL_TIME_FLOW
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER

View File

@ -21,7 +21,6 @@ from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.utils.logger import setup_logger
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
@ -294,7 +293,8 @@ class IntentModel:
def warm_up_models(
indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
skip_cross_encoders: bool = False,
indexer_only: bool = False,
) -> None:
warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at "

View File

@ -16,8 +16,10 @@ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
@ -392,7 +394,7 @@ def retrieve_chunks(
def should_rerank(query: SearchQuery) -> bool:
# don't re-rank for keyword search
# Don't re-rank for keyword search
return query.search_type != SearchType.KEYWORD and not query.skip_rerank
@ -556,6 +558,8 @@ def danswer_search_generator(
db_session: Session,
document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
skip_rerank_non_realtime: bool = SKIP_RERANKING,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
@ -563,7 +567,7 @@ def danswer_search_generator(
) -> Iterator[list[InferenceChunk] | list[bool] | int]:
"""The main entry point for search. This fetches the relevant documents from Vespa
based on the provided query (applying permissions / filters), does any specified
post-processing, and returns the results. It also create an entry in the query_event table
post-processing, and returns the results. It also creates an entry in the query_event table
for this search event."""
query_event_id = create_query_event(
query=question.query,
@ -583,6 +587,10 @@ def danswer_search_generator(
access_control_list=user_acl_filters,
)
skip_reranking = (
skip_rerank_realtime if question.real_time else skip_rerank_non_realtime
)
search_query = SearchQuery(
query=question.query,
search_type=question.search_type,
@ -591,6 +599,7 @@ def danswer_search_generator(
favor_recent=question.favor_recent
if question.favor_recent is not None
else False,
skip_rerank=skip_reranking,
skip_llm_chunk_filter=skip_llm_chunk_filter,
)

View File

@ -179,6 +179,9 @@ class QuestionRequest(BaseModel):
search_type: SearchType = SearchType.HYBRID
enable_auto_detect_filters: bool = True
favor_recent: bool | None = None
# Is this a real-time/streaming call or a question where Danswer can take more time?
real_time: bool = True
# Pagination purposes, offset is in batches, not by document count
offset: int | None = None

View File

@ -84,6 +84,7 @@ def get_answer_for_question(
question = QuestionRequest(
query=query,
filters=filters,
real_time=False,
enable_auto_detect_filters=False,
)
@ -96,7 +97,6 @@ def get_answer_for_question(
user=None,
db_session=db_session,
answer_generation_timeout=100,
real_time_flow=False,
enable_reflexion=False,
bypass_acl=True,
retrieval_metrics_callback=retrieval_metrics.record_metric,

View File

@ -26,7 +26,7 @@ engine = get_sqlalchemy_engine()
def redirect_print_to_file(file: TextIO) -> Any:
original_print = builtins.print
def new_print(*args, **kwargs):
def new_print(*args: Any, **kwargs: Any) -> Any:
kwargs["file"] = file
original_print(*args, **kwargs)

View File

@ -45,7 +45,7 @@ services:
- SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- SKIP_RERANKING=${SKIP_RERANKING:-}
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}