Rework Rerankers (#2093)

parent 7dcc42aa95
commit 8cd1eda8b1
mirror of https://github.com/danswer-ai/danswer.git
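
In short: the model server drops its cross-encoder ensemble in favor of a single reranker. The client-side CrossEncoderEnsembleModel becomes RerankingModel (now carrying model_name and api_key per request), score payloads flatten from list[list[float] | None] to list[float], the ensemble loader and warm-up become single-model equivalents, and a /gpu-status endpoint is added to the management router.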
@@ -190,17 +190,26 @@ class EmbeddingModel:
         )


-class CrossEncoderEnsembleModel:
+class RerankingModel:
     def __init__(
         self,
+        model_name: str,
+        api_key: str | None,
         model_server_host: str = MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
         model_server_url = build_model_server_url(model_server_host, model_server_port)
         self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
+        self.model_name = model_name
+        self.api_key = api_key

-    def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
-        rerank_request = RerankRequest(query=query, documents=passages)
+    def predict(self, query: str, passages: list[str]) -> list[float]:
+        rerank_request = RerankRequest(
+            query=query,
+            documents=passages,
+            model_name=self.model_name,
+            api_key=self.api_key,
+        )

         response = requests.post(
             self.rerank_server_endpoint, json=rerank_request.dict()
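
For context, a minimal usage sketch of the reworked client class (the query and passage strings are made-up examples; the model server must be running for the HTTP call inside predict to succeed):

    from danswer.natural_language_processing.search_nlp_models import RerankingModel

    # Assumed example values; the default model name comes from shared_configs
    reranker = RerankingModel(
        model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
        api_key=None,  # local model, no provider key needed
    )
    scores = reranker.predict(
        query="how do I reset my password?",
        passages=["Resetting your password", "Quarterly revenue report"],
    )
    # scores is now a flat list[float], one relevance score per passage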
@@ -13,9 +13,7 @@ from danswer.document_index.document_index_utils import (
     translate_boost_count_to_multiplier,
 )
 from danswer.llm.interfaces import LLM
-from danswer.natural_language_processing.search_nlp_models import (
-    CrossEncoderEnsembleModel,
-)
+from danswer.natural_language_processing.search_nlp_models import RerankingModel
 from danswer.search.enums import LLMEvaluationType
 from danswer.search.models import ChunkMetric
 from danswer.search.models import InferenceChunk
@@ -30,6 +28,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
+from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME


 logger = setup_logger()
@@ -96,15 +95,20 @@ def semantic_reranking(

     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
-    cross_encoders = CrossEncoderEnsembleModel()
+    # TODO update this
+    cross_encoder = RerankingModel(
+        model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
+        api_key=None,
+    )

     passages = [
         f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
         for chunk in chunks
     ]
-    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
+    sim_scores_floats = cross_encoder.predict(query=query, passages=passages)

-    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+    # Old logic to handle multiple cross-encoders preserved but not used
+    sim_scores = [numpy.array(sim_scores_floats)]

     raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))

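
The single-element wrapping above keeps the old ensemble-averaging arithmetic intact; a small sketch of why (illustrative numbers only):

    import numpy

    sim_scores_floats = [0.9, 0.1, 0.4]            # illustrative scores from predict()
    sim_scores = [numpy.array(sim_scores_floats)]  # a list holding a single array
    # With one model, the "average over encoders" is a no-op:
    raw_sim_scores = sum(sim_scores) / len(sim_scores)
    assert (raw_sim_scores == numpy.array(sim_scores_floats)).all()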
@@ -1,4 +1,3 @@
-import gc
 import json
 from typing import Any
 from typing import Optional
@@ -25,8 +24,7 @@ from model_server.constants import EmbeddingModelTextType
 from model_server.constants import EmbeddingProvider
 from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.utils import simple_log_function_time
-from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
-from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.enums import EmbedTextType
 from shared_configs.model_server_models import Embedding
@@ -42,7 +40,8 @@ logger = setup_logger()
 router = APIRouter(prefix="/encoder")

 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
-_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
+_RERANK_MODEL: Optional["CrossEncoder"] = None

+# If we are not only indexing, dont want retry very long
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
@@ -229,32 +228,22 @@ def get_embedding_model(
     return _GLOBAL_MODELS_DICT[model_name]


-def get_local_reranking_model_ensemble(
-    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
-    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
-) -> list[CrossEncoder]:
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
-        del _RERANK_MODELS
-        gc.collect()
-
-        _RERANK_MODELS = []
-        for model_name in model_names:
-            logger.info(f"Loading {model_name}")
-            model = CrossEncoder(model_name)
-            model.max_length = max_context_length
-            _RERANK_MODELS.append(model)
-    return _RERANK_MODELS
+def get_local_reranking_model(
+    model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
+) -> CrossEncoder:
+    global _RERANK_MODEL
+    if _RERANK_MODEL is None:
+        logger.info(f"Loading {model_name}")
+        model = CrossEncoder(model_name)
+        _RERANK_MODEL = model
+    return _RERANK_MODEL


-def warm_up_cross_encoders() -> None:
-    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+def warm_up_cross_encoder() -> None:
+    logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")

-    cross_encoders = get_local_reranking_model_ensemble()
-    [
-        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
-        for cross_encoder in cross_encoders
-    ]
+    cross_encoder = get_local_reranking_model()
+    cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))


 @simple_log_function_time()
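
A quick sketch of the caching behavior the new loader implies (assumed usage, not part of the commit): the CrossEncoder is loaded once and reused for the life of the process, which is why the gc.collect() path above could be deleted.

    from model_server.encoders import get_local_reranking_model

    # First call loads the model from disk; later calls return the cached instance.
    first = get_local_reranking_model()
    second = get_local_reranking_model()
    assert first is second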
@@ -325,13 +314,9 @@ def embed_text(


 @simple_log_function_time()
-def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
-    cross_encoders = get_local_reranking_model_ensemble()
-    sim_scores = [
-        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
-        for encoder in cross_encoders
-    ]
-    return sim_scores
+def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
+    cross_encoder = get_local_reranking_model()
+    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore


 @router.post("/bi-encoder-embed")
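
For illustration, the simplified helper now returns one score per document (the values below are made up):

    from model_server.encoders import calc_sim_scores

    scores = calc_sim_scores(
        query="reset my password",
        docs=["How to reset your password", "Quarterly revenue report"],
    )
    # e.g. scores == [0.87, 0.02]: a flat list[float], higher = more relevant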
@@ -15,7 +15,7 @@ from danswer.utils.logger import setup_logger
 from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
-from model_server.encoders import warm_up_cross_encoders
+from model_server.encoders import warm_up_cross_encoder
 from model_server.management_endpoints import router as management_router
 from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
 from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
@@ -65,7 +65,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     if not INDEXING_ONLY:
         warm_up_intent_model()
         if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
-            warm_up_cross_encoders()
+            warm_up_cross_encoder()
     else:
         logger.info("This model server should only run document indexing.")

@@ -1,3 +1,4 @@
+import torch
 from fastapi import APIRouter
 from fastapi import Response

@@ -7,3 +8,9 @@ router = APIRouter(prefix="/api")
 @router.get("/health")
 def healthcheck() -> Response:
     return Response(status_code=200)
+
+
+@router.get("/gpu-status")
+def gpu_status() -> dict[str, bool]:
+    has_gpu = torch.cuda.is_available()
+    return {"gpu_available": has_gpu}
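
For illustration, the new endpoint can be exercised like this (assuming the model server listens on localhost:9000; adjust the host and port to your deployment):

    import requests

    resp = requests.get("http://localhost:9000/api/gpu-status")
    print(resp.json())  # {'gpu_available': True} when CUDA is available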
@@ -27,9 +27,8 @@ ENABLE_RERANKING_ASYNC_FLOW = (
 ENABLE_RERANKING_REAL_TIME_FLOW = (
     os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
 )
-# Only using one cross-encoder for now
-CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
-CROSS_EMBED_CONTEXT_SIZE = 512
+
+DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"

 # This controls the minimum number of pytorch "threads" to allocate to the embedding
 # model. If torch finds more threads on its own, this value is not used.
@@ -25,10 +25,12 @@ class EmbedResponse(BaseModel):
 class RerankRequest(BaseModel):
     query: str
     documents: list[str]
+    model_name: str
+    api_key: str | None


 class RerankResponse(BaseModel):
-    scores: list[list[float] | None]
+    scores: list[float]


 class IntentRequest(BaseModel):
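
Sketch of the reworked request/response contract with example values (the score numbers are made up):

    from shared_configs.model_server_models import RerankRequest
    from shared_configs.model_server_models import RerankResponse

    req = RerankRequest(
        query="example query",
        documents=["doc one", "doc two"],
        model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
        api_key=None,
    )
    # The server now replies with a flat score list, one float per document:
    resp = RerankResponse(scores=[0.73, 0.12])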