mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Rework Rerankers (#2093)
This commit is contained in:
@@ -190,17 +190,26 @@ class EmbeddingModel:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderEnsembleModel:
|
class RerankingModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
model_name: str,
|
||||||
|
api_key: str | None,
|
||||||
model_server_host: str = MODEL_SERVER_HOST,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
model_server_port: int = MODEL_SERVER_PORT,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
) -> None:
|
) -> None:
|
||||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||||
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||||
|
self.model_name = model_name
|
||||||
|
self.api_key = api_key
|
||||||
|
|
||||||
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
|
def predict(self, query: str, passages: list[str]) -> list[float]:
|
||||||
rerank_request = RerankRequest(query=query, documents=passages)
|
rerank_request = RerankRequest(
|
||||||
|
query=query,
|
||||||
|
documents=passages,
|
||||||
|
model_name=self.model_name,
|
||||||
|
api_key=self.api_key,
|
||||||
|
)
|
||||||
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.rerank_server_endpoint, json=rerank_request.dict()
|
self.rerank_server_endpoint, json=rerank_request.dict()
|
||||||
|
@@ -13,9 +13,7 @@ from danswer.document_index.document_index_utils import (
|
|||||||
translate_boost_count_to_multiplier,
|
translate_boost_count_to_multiplier,
|
||||||
)
|
)
|
||||||
from danswer.llm.interfaces import LLM
|
from danswer.llm.interfaces import LLM
|
||||||
from danswer.natural_language_processing.search_nlp_models import (
|
from danswer.natural_language_processing.search_nlp_models import RerankingModel
|
||||||
CrossEncoderEnsembleModel,
|
|
||||||
)
|
|
||||||
from danswer.search.enums import LLMEvaluationType
|
from danswer.search.enums import LLMEvaluationType
|
||||||
from danswer.search.models import ChunkMetric
|
from danswer.search.models import ChunkMetric
|
||||||
from danswer.search.models import InferenceChunk
|
from danswer.search.models import InferenceChunk
|
||||||
@@ -30,6 +28,7 @@ from danswer.utils.logger import setup_logger
|
|||||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||||
from danswer.utils.timing import log_function_time
|
from danswer.utils.timing import log_function_time
|
||||||
|
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@@ -96,15 +95,20 @@ def semantic_reranking(
|
|||||||
|
|
||||||
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||||
"""
|
"""
|
||||||
cross_encoders = CrossEncoderEnsembleModel()
|
# TODO update this
|
||||||
|
cross_encoder = RerankingModel(
|
||||||
|
model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||||
|
api_key=None,
|
||||||
|
)
|
||||||
|
|
||||||
passages = [
|
passages = [
|
||||||
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
|
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
|
||||||
for chunk in chunks
|
for chunk in chunks
|
||||||
]
|
]
|
||||||
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
|
sim_scores_floats = cross_encoder.predict(query=query, passages=passages)
|
||||||
|
|
||||||
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
|
# Old logic to handle multiple cross-encoders preserved but not used
|
||||||
|
sim_scores = [numpy.array(sim_scores_floats)]
|
||||||
|
|
||||||
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
|
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
|
||||||
|
|
||||||
|
@@ -1,4 +1,3 @@
|
|||||||
import gc
|
|
||||||
import json
|
import json
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
@@ -25,8 +24,7 @@ from model_server.constants import EmbeddingModelTextType
|
|||||||
from model_server.constants import EmbeddingProvider
|
from model_server.constants import EmbeddingProvider
|
||||||
from model_server.constants import MODEL_WARM_UP_STRING
|
from model_server.constants import MODEL_WARM_UP_STRING
|
||||||
from model_server.utils import simple_log_function_time
|
from model_server.utils import simple_log_function_time
|
||||||
from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
|
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||||
from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
|
|
||||||
from shared_configs.configs import INDEXING_ONLY
|
from shared_configs.configs import INDEXING_ONLY
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
from shared_configs.model_server_models import Embedding
|
from shared_configs.model_server_models import Embedding
|
||||||
@@ -42,7 +40,8 @@ logger = setup_logger()
|
|||||||
router = APIRouter(prefix="/encoder")
|
router = APIRouter(prefix="/encoder")
|
||||||
|
|
||||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||||
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
|
_RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||||
|
|
||||||
# If we are not only indexing, dont want retry very long
|
# If we are not only indexing, dont want retry very long
|
||||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||||
@@ -229,32 +228,22 @@ def get_embedding_model(
|
|||||||
return _GLOBAL_MODELS_DICT[model_name]
|
return _GLOBAL_MODELS_DICT[model_name]
|
||||||
|
|
||||||
|
|
||||||
def get_local_reranking_model_ensemble(
|
def get_local_reranking_model(
|
||||||
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
|
model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||||
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
|
) -> CrossEncoder:
|
||||||
) -> list[CrossEncoder]:
|
global _RERANK_MODEL
|
||||||
global _RERANK_MODELS
|
if _RERANK_MODEL is None:
|
||||||
if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
|
logger.info(f"Loading {model_name}")
|
||||||
del _RERANK_MODELS
|
model = CrossEncoder(model_name)
|
||||||
gc.collect()
|
_RERANK_MODEL = model
|
||||||
|
return _RERANK_MODEL
|
||||||
_RERANK_MODELS = []
|
|
||||||
for model_name in model_names:
|
|
||||||
logger.info(f"Loading {model_name}")
|
|
||||||
model = CrossEncoder(model_name)
|
|
||||||
model.max_length = max_context_length
|
|
||||||
_RERANK_MODELS.append(model)
|
|
||||||
return _RERANK_MODELS
|
|
||||||
|
|
||||||
|
|
||||||
def warm_up_cross_encoders() -> None:
|
def warm_up_cross_encoder() -> None:
|
||||||
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
|
logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")
|
||||||
|
|
||||||
cross_encoders = get_local_reranking_model_ensemble()
|
cross_encoder = get_local_reranking_model()
|
||||||
[
|
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
|
||||||
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
|
|
||||||
for cross_encoder in cross_encoders
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@simple_log_function_time()
|
@simple_log_function_time()
|
||||||
@@ -325,13 +314,9 @@ def embed_text(
|
|||||||
|
|
||||||
|
|
||||||
@simple_log_function_time()
|
@simple_log_function_time()
|
||||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
|
def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
|
||||||
cross_encoders = get_local_reranking_model_ensemble()
|
cross_encoder = get_local_reranking_model()
|
||||||
sim_scores = [
|
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||||
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
|
||||||
for encoder in cross_encoders
|
|
||||||
]
|
|
||||||
return sim_scores
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/bi-encoder-embed")
|
@router.post("/bi-encoder-embed")
|
||||||
|
@@ -15,7 +15,7 @@ from danswer.utils.logger import setup_logger
|
|||||||
from model_server.custom_models import router as custom_models_router
|
from model_server.custom_models import router as custom_models_router
|
||||||
from model_server.custom_models import warm_up_intent_model
|
from model_server.custom_models import warm_up_intent_model
|
||||||
from model_server.encoders import router as encoders_router
|
from model_server.encoders import router as encoders_router
|
||||||
from model_server.encoders import warm_up_cross_encoders
|
from model_server.encoders import warm_up_cross_encoder
|
||||||
from model_server.management_endpoints import router as management_router
|
from model_server.management_endpoints import router as management_router
|
||||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||||
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||||
@@ -65,7 +65,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
|||||||
if not INDEXING_ONLY:
|
if not INDEXING_ONLY:
|
||||||
warm_up_intent_model()
|
warm_up_intent_model()
|
||||||
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
|
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
|
||||||
warm_up_cross_encoders()
|
warm_up_cross_encoder()
|
||||||
else:
|
else:
|
||||||
logger.info("This model server should only run document indexing.")
|
logger.info("This model server should only run document indexing.")
|
||||||
|
|
||||||
|
@@ -1,3 +1,4 @@
|
|||||||
|
import torch
|
||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from fastapi import Response
|
from fastapi import Response
|
||||||
|
|
||||||
@@ -7,3 +8,9 @@ router = APIRouter(prefix="/api")
|
|||||||
@router.get("/health")
|
@router.get("/health")
|
||||||
def healthcheck() -> Response:
|
def healthcheck() -> Response:
|
||||||
return Response(status_code=200)
|
return Response(status_code=200)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/gpu-status")
|
||||||
|
def gpu_status() -> dict[str, bool]:
|
||||||
|
has_gpu = torch.cuda.is_available()
|
||||||
|
return {"gpu_available": has_gpu}
|
||||||
|
@@ -27,9 +27,8 @@ ENABLE_RERANKING_ASYNC_FLOW = (
|
|||||||
ENABLE_RERANKING_REAL_TIME_FLOW = (
|
ENABLE_RERANKING_REAL_TIME_FLOW = (
|
||||||
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
|
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
|
||||||
)
|
)
|
||||||
# Only using one cross-encoder for now
|
|
||||||
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
|
DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
|
||||||
CROSS_EMBED_CONTEXT_SIZE = 512
|
|
||||||
|
|
||||||
# This controls the minimum number of pytorch "threads" to allocate to the embedding
|
# This controls the minimum number of pytorch "threads" to allocate to the embedding
|
||||||
# model. If torch finds more threads on its own, this value is not used.
|
# model. If torch finds more threads on its own, this value is not used.
|
||||||
|
@@ -25,10 +25,12 @@ class EmbedResponse(BaseModel):
|
|||||||
class RerankRequest(BaseModel):
|
class RerankRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
documents: list[str]
|
documents: list[str]
|
||||||
|
model_name: str
|
||||||
|
api_key: str | None
|
||||||
|
|
||||||
|
|
||||||
class RerankResponse(BaseModel):
|
class RerankResponse(BaseModel):
|
||||||
scores: list[list[float] | None]
|
scores: list[float]
|
||||||
|
|
||||||
|
|
||||||
class IntentRequest(BaseModel):
|
class IntentRequest(BaseModel):
|
||||||
|
Reference in New Issue
Block a user