Rework Rerankers (#2093)

Yuhong Sun 2024-08-08 21:33:49 -07:00 committed by GitHub
parent 7dcc42aa95
commit 8cd1eda8b1
7 changed files with 55 additions and 49 deletions

backend/danswer/natural_language_processing/search_nlp_models.py

@@ -190,17 +190,26 @@ class EmbeddingModel:
     )


-class CrossEncoderEnsembleModel:
+class RerankingModel:
     def __init__(
         self,
+        model_name: str,
+        api_key: str | None,
         model_server_host: str = MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
         model_server_url = build_model_server_url(model_server_host, model_server_port)
         self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
+        self.model_name = model_name
+        self.api_key = api_key

-    def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
-        rerank_request = RerankRequest(query=query, documents=passages)
+    def predict(self, query: str, passages: list[str]) -> list[float]:
+        rerank_request = RerankRequest(
+            query=query,
+            documents=passages,
+            model_name=self.model_name,
+            api_key=self.api_key,
+        )

         response = requests.post(
             self.rerank_server_endpoint, json=rerank_request.dict()
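
Reviewer note: a minimal usage sketch of the reworked client (the query, passages, and model name below are illustrative; host and port fall back to the MODEL_SERVER_HOST/MODEL_SERVER_PORT defaults shown above):

from danswer.natural_language_processing.search_nlp_models import RerankingModel

reranker = RerankingModel(
    model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",  # illustrative; any supported cross-encoder
    api_key=None,  # local model server, no key needed
)
# predict() now returns a flat list[float], one score per passage
scores = reranker.predict(
    query="What is the refund policy?",
    passages=["Refunds are issued within 30 days.", "Shipping takes 5 days."],
)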

backend/danswer/search/postprocessing/postprocessing.py

@@ -13,9 +13,7 @@ from danswer.document_index.document_index_utils import (
     translate_boost_count_to_multiplier,
 )
 from danswer.llm.interfaces import LLM
-from danswer.natural_language_processing.search_nlp_models import (
-    CrossEncoderEnsembleModel,
-)
+from danswer.natural_language_processing.search_nlp_models import RerankingModel
 from danswer.search.enums import LLMEvaluationType
 from danswer.search.models import ChunkMetric
 from danswer.search.models import InferenceChunk
@@ -30,6 +28,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
+from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME

 logger = setup_logger()
@@ -96,15 +95,20 @@ def semantic_reranking(
     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
-    cross_encoders = CrossEncoderEnsembleModel()
+    # TODO update this
+    cross_encoder = RerankingModel(
+        model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
+        api_key=None,
+    )

     passages = [
         f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
         for chunk in chunks
     ]

-    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
+    sim_scores_floats = cross_encoder.predict(query=query, passages=passages)

-    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+    # Old logic to handle multiple cross-encoders preserved but not used
+    sim_scores = [numpy.array(sim_scores_floats)]

     raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
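
Reviewer note: with a single reranker, the preserved ensemble-averaging step reduces to the identity. A toy check (score values hypothetical):

import numpy

sim_scores_floats = [0.91, 0.42, 0.13]  # one score per passage, as predict() now returns
sim_scores = [numpy.array(sim_scores_floats)]  # a single "ensemble member"
raw_sim_scores = sum(sim_scores) / len(sim_scores)  # len == 1, so this is a no-op
assert numpy.allclose(raw_sim_scores, sim_scores_floats)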

backend/model_server/encoders.py

@@ -1,4 +1,3 @@
-import gc
 import json
 from typing import Any
 from typing import Optional
@@ -25,8 +24,7 @@ from model_server.constants import EmbeddingModelTextType
 from model_server.constants import EmbeddingProvider
 from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.utils import simple_log_function_time
-from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
-from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.enums import EmbedTextType
 from shared_configs.model_server_models import Embedding
@@ -42,7 +40,8 @@ logger = setup_logger()
 router = APIRouter(prefix="/encoder")

 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
-_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
+_RERANK_MODEL: Optional["CrossEncoder"] = None

+# If we are not only indexing, dont want retry very long
 _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
 _RETRY_TRIES = 10 if INDEXING_ONLY else 2
@@ -229,32 +228,22 @@ def get_embedding_model(
     return _GLOBAL_MODELS_DICT[model_name]


-def get_local_reranking_model_ensemble(
-    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
-    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
-) -> list[CrossEncoder]:
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
-        del _RERANK_MODELS
-        gc.collect()
-        _RERANK_MODELS = []
-        for model_name in model_names:
-            logger.info(f"Loading {model_name}")
-            model = CrossEncoder(model_name)
-            model.max_length = max_context_length
-            _RERANK_MODELS.append(model)
-    return _RERANK_MODELS
+def get_local_reranking_model(
+    model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
+) -> CrossEncoder:
+    global _RERANK_MODEL
+    if _RERANK_MODEL is None:
+        logger.info(f"Loading {model_name}")
+        model = CrossEncoder(model_name)
+        _RERANK_MODEL = model
+    return _RERANK_MODEL


-def warm_up_cross_encoders() -> None:
-    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+def warm_up_cross_encoder() -> None:
+    logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")

-    cross_encoders = get_local_reranking_model_ensemble()
-    [
-        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
-        for cross_encoder in cross_encoders
-    ]
+    cross_encoder = get_local_reranking_model()
+    cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))


 @simple_log_function_time()
@@ -325,13 +314,9 @@ def embed_text(
 @simple_log_function_time()
-def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
-    cross_encoders = get_local_reranking_model_ensemble()
-    sim_scores = [
-        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
-        for encoder in cross_encoders
-    ]
-    return sim_scores
+def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
+    cross_encoder = get_local_reranking_model()
+    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore


 @router.post("/bi-encoder-embed")
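
Reviewer note: a hedged sketch of calling the reworked endpoint directly (the localhost:9000 address is an assumption; the path and field names follow RerankingModel and RerankRequest/RerankResponse in this PR):

import requests

response = requests.post(
    "http://localhost:9000/encoder/cross-encoder-scores",  # assumed model server address
    json={
        "query": "What is the refund policy?",
        "documents": ["Refunds are issued within 30 days.", "Shipping takes 5 days."],
        "model_name": "mixedbread-ai/mxbai-rerank-xsmall-v1",
        "api_key": None,
    },
)
scores: list[float] = response.json()["scores"]  # now a flat list, one score per document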

backend/model_server/main.py

@@ -15,7 +15,7 @@ from danswer.utils.logger import setup_logger
 from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
-from model_server.encoders import warm_up_cross_encoders
+from model_server.encoders import warm_up_cross_encoder
 from model_server.management_endpoints import router as management_router
 from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
 from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
@@ -65,7 +65,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     if not INDEXING_ONLY:
         warm_up_intent_model()
         if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
-            warm_up_cross_encoders()
+            warm_up_cross_encoder()
     else:
         logger.info("This model server should only run document indexing.")

backend/model_server/management_endpoints.py

@@ -1,3 +1,4 @@
+import torch
 from fastapi import APIRouter
 from fastapi import Response
@@ -7,3 +8,9 @@ router = APIRouter(prefix="/api")
 @router.get("/health")
 def healthcheck() -> Response:
     return Response(status_code=200)
+
+
+@router.get("/gpu-status")
+def gpu_status() -> dict[str, bool]:
+    has_gpu = torch.cuda.is_available()
+    return {"gpu_available": has_gpu}

backend/shared_configs/configs.py

@@ -27,9 +27,8 @@ ENABLE_RERANKING_ASYNC_FLOW = (
 ENABLE_RERANKING_REAL_TIME_FLOW = (
     os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
 )
 # Only using one cross-encoder for now
-CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
-CROSS_EMBED_CONTEXT_SIZE = 512
+DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"

 # This controls the minimum number of pytorch "threads" to allocate to the embedding
 # model. If torch finds more threads on its own, this value is not used.

backend/shared_configs/model_server_models.py

@@ -25,10 +25,12 @@ class EmbedResponse(BaseModel):
 class RerankRequest(BaseModel):
     query: str
     documents: list[str]
+    model_name: str
+    api_key: str | None


 class RerankResponse(BaseModel):
-    scores: list[list[float] | None]
+    scores: list[float]


 class IntentRequest(BaseModel):
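
Reviewer note: a small sanity-check sketch of the updated schema (Pydantic v1 style, matching the .dict() call in search_nlp_models.py; values illustrative):

from shared_configs.model_server_models import RerankRequest, RerankResponse

request = RerankRequest(
    query="refund policy",
    documents=["doc one", "doc two"],
    model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
    api_key=None,  # optional per the new str | None field
)
payload = request.dict()  # the JSON body that RerankingModel.predict() POSTs
response = RerankResponse(scores=[0.9, 0.1])  # scores is now a flat list[float]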