From 8cd1eda8b1d94290cb9468ed79992d6a957e5633 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 8 Aug 2024 21:33:49 -0700 Subject: [PATCH] Rework Rerankers (#2093) --- .../search_nlp_models.py | 15 ++++-- .../search/postprocessing/postprocessing.py | 16 +++--- backend/model_server/encoders.py | 53 +++++++------------ backend/model_server/main.py | 4 +- backend/model_server/management_endpoints.py | 7 +++ backend/shared_configs/configs.py | 5 +- backend/shared_configs/model_server_models.py | 4 +- 7 files changed, 55 insertions(+), 49 deletions(-) diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py index f25a07dc2..6757fada2 100644 --- a/backend/danswer/natural_language_processing/search_nlp_models.py +++ b/backend/danswer/natural_language_processing/search_nlp_models.py @@ -190,17 +190,26 @@ class EmbeddingModel: ) -class CrossEncoderEnsembleModel: +class RerankingModel: def __init__( self, + model_name: str, + api_key: str | None, model_server_host: str = MODEL_SERVER_HOST, model_server_port: int = MODEL_SERVER_PORT, ) -> None: model_server_url = build_model_server_url(model_server_host, model_server_port) self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores" + self.model_name = model_name + self.api_key = api_key - def predict(self, query: str, passages: list[str]) -> list[list[float] | None]: - rerank_request = RerankRequest(query=query, documents=passages) + def predict(self, query: str, passages: list[str]) -> list[float]: + rerank_request = RerankRequest( + query=query, + documents=passages, + model_name=self.model_name, + api_key=self.api_key, + ) response = requests.post( self.rerank_server_endpoint, json=rerank_request.dict() diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py index fdc2e5f4c..aa3995b5e 100644 --- a/backend/danswer/search/postprocessing/postprocessing.py +++ b/backend/danswer/search/postprocessing/postprocessing.py @@ -13,9 +13,7 @@ from danswer.document_index.document_index_utils import ( translate_boost_count_to_multiplier, ) from danswer.llm.interfaces import LLM -from danswer.natural_language_processing.search_nlp_models import ( - CrossEncoderEnsembleModel, -) +from danswer.natural_language_processing.search_nlp_models import RerankingModel from danswer.search.enums import LLMEvaluationType from danswer.search.models import ChunkMetric from danswer.search.models import InferenceChunk @@ -30,6 +28,7 @@ from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time +from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME logger = setup_logger() @@ -96,15 +95,20 @@ def semantic_reranking( Note: this updates the chunks in place, it updates the chunk scores which came from retrieval """ - cross_encoders = CrossEncoderEnsembleModel() + # TODO update this + cross_encoder = RerankingModel( + model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME, + api_key=None, + ) passages = [ f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}" for chunk in chunks ] - sim_scores_floats = cross_encoders.predict(query=query, passages=passages) + sim_scores_floats = cross_encoder.predict(query=query, passages=passages) - sim_scores = [numpy.array(scores) for scores in sim_scores_floats] + # Old logic to handle multiple cross-encoders preserved but not used + sim_scores = [numpy.array(sim_scores_floats)] raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores)) diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index f12aeb89b..cda6cb139 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,4 +1,3 @@ -import gc import json from typing import Any from typing import Optional @@ -25,8 +24,7 @@ from model_server.constants import EmbeddingModelTextType from model_server.constants import EmbeddingProvider from model_server.constants import MODEL_WARM_UP_STRING from model_server.utils import simple_log_function_time -from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE -from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE +from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME from shared_configs.configs import INDEXING_ONLY from shared_configs.enums import EmbedTextType from shared_configs.model_server_models import Embedding @@ -42,7 +40,8 @@ logger = setup_logger() router = APIRouter(prefix="/encoder") _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {} -_RERANK_MODELS: Optional[list["CrossEncoder"]] = None +_RERANK_MODEL: Optional["CrossEncoder"] = None + # If we are not only indexing, dont want retry very long _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1 _RETRY_TRIES = 10 if INDEXING_ONLY else 2 @@ -229,32 +228,22 @@ def get_embedding_model( return _GLOBAL_MODELS_DICT[model_name] -def get_local_reranking_model_ensemble( - model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE, - max_context_length: int = CROSS_EMBED_CONTEXT_SIZE, -) -> list[CrossEncoder]: - global _RERANK_MODELS - if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length: - del _RERANK_MODELS - gc.collect() - - _RERANK_MODELS = [] - for model_name in model_names: - logger.info(f"Loading {model_name}") - model = CrossEncoder(model_name) - model.max_length = max_context_length - _RERANK_MODELS.append(model) - return _RERANK_MODELS +def get_local_reranking_model( + model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME, +) -> CrossEncoder: + global _RERANK_MODEL + if _RERANK_MODEL is None: + logger.info(f"Loading {model_name}") + model = CrossEncoder(model_name) + _RERANK_MODEL = model + return _RERANK_MODEL -def warm_up_cross_encoders() -> None: - logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}") +def warm_up_cross_encoder() -> None: + logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}") - cross_encoders = get_local_reranking_model_ensemble() - [ - cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING)) - for cross_encoder in cross_encoders - ] + cross_encoder = get_local_reranking_model() + cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING)) @simple_log_function_time() @@ -325,13 +314,9 @@ def embed_text( @simple_log_function_time() -def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]: - cross_encoders = get_local_reranking_model_ensemble() - sim_scores = [ - encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore - for encoder in cross_encoders - ] - return sim_scores +def calc_sim_scores(query: str, docs: list[str]) -> list[float]: + cross_encoder = get_local_reranking_model() + return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore @router.post("/bi-encoder-embed") diff --git a/backend/model_server/main.py b/backend/model_server/main.py index 87059d634..efb504c6b 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -15,7 +15,7 @@ from danswer.utils.logger import setup_logger from model_server.custom_models import router as custom_models_router from model_server.custom_models import warm_up_intent_model from model_server.encoders import router as encoders_router -from model_server.encoders import warm_up_cross_encoders +from model_server.encoders import warm_up_cross_encoder from model_server.management_endpoints import router as management_router from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW @@ -65,7 +65,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: if not INDEXING_ONLY: warm_up_intent_model() if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW: - warm_up_cross_encoders() + warm_up_cross_encoder() else: logger.info("This model server should only run document indexing.") diff --git a/backend/model_server/management_endpoints.py b/backend/model_server/management_endpoints.py index fc1b8901e..d2d45d69d 100644 --- a/backend/model_server/management_endpoints.py +++ b/backend/model_server/management_endpoints.py @@ -1,3 +1,4 @@ +import torch from fastapi import APIRouter from fastapi import Response @@ -7,3 +8,9 @@ router = APIRouter(prefix="/api") @router.get("/health") def healthcheck() -> Response: return Response(status_code=200) + + +@router.get("/gpu-status") +def gpu_status() -> dict[str, bool]: + has_gpu = torch.cuda.is_available() + return {"gpu_available": has_gpu} diff --git a/backend/shared_configs/configs.py b/backend/shared_configs/configs.py index 67bb83463..c893c4357 100644 --- a/backend/shared_configs/configs.py +++ b/backend/shared_configs/configs.py @@ -27,9 +27,8 @@ ENABLE_RERANKING_ASYNC_FLOW = ( ENABLE_RERANKING_REAL_TIME_FLOW = ( os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true" ) -# Only using one cross-encoder for now -CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"] -CROSS_EMBED_CONTEXT_SIZE = 512 + +DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1" # This controls the minimum number of pytorch "threads" to allocate to the embedding # model. If torch finds more threads on its own, this value is not used. diff --git a/backend/shared_configs/model_server_models.py b/backend/shared_configs/model_server_models.py index f00a1067a..aa024d7e8 100644 --- a/backend/shared_configs/model_server_models.py +++ b/backend/shared_configs/model_server_models.py @@ -25,10 +25,12 @@ class EmbedResponse(BaseModel): class RerankRequest(BaseModel): query: str documents: list[str] + model_name: str + api_key: str | None class RerankResponse(BaseModel): - scores: list[list[float] | None] + scores: list[float] class IntentRequest(BaseModel):