mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-12 14:12:53 +02:00
Cohere Rerank (#2109)
This commit is contained in:
@ -22,11 +22,10 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.utils import simple_log_function_time
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
@ -226,7 +225,7 @@ def get_embedding_model(
|
||||
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||
model_name: str,
|
||||
) -> CrossEncoder:
|
||||
global _RERANK_MODEL
|
||||
if _RERANK_MODEL is None:
|
||||
@ -236,13 +235,6 @@ def get_local_reranking_model(
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
def warm_up_cross_encoder() -> None:
    """Run one throwaway prediction through the default cross-encoder.

    The model name is passed to ``get_local_reranking_model`` explicitly so
    that the model being warmed up is the one named in the log line, and so
    the call does not depend on that function having a default for
    ``model_name``.
    """
    logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")

    cross_encoder = get_local_reranking_model(DEFAULT_CROSS_ENCODER_MODEL_NAME)
    # The prediction result is discarded; only the side effect of running
    # the model once is wanted here.
    cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
@ -311,11 +303,21 @@ def embed_text(
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model()
|
||||
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
    """Score every document in ``docs`` against ``query`` with a local model.

    The model is obtained from ``get_local_reranking_model(model_name)``.
    Each document is paired with the query, the pairs are passed to the
    model's ``predict`` in a single call, and the result is converted with
    ``.tolist()`` before being returned.
    """
    reranker = get_local_reranking_model(model_name)
    query_doc_pairs = [(query, doc) for doc in docs]
    raw_scores = reranker.predict(query_doc_pairs)
    return raw_scores.tolist()  # type: ignore
|
||||
|
||||
|
||||
def cohere_rerank(
    query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
    """Score ``docs`` against ``query`` using a Cohere rerank model.

    A new ``CohereClient`` is created from ``api_key`` on every call. The
    results are sorted by each result's ``index`` attribute before the
    relevance scores are extracted, so the scores come back in index order
    rather than in the order the API returned them.
    """
    client = CohereClient(api_key=api_key)
    reranked = client.rerank(query=query, documents=docs, model=model_name)

    # NOTE(review): ``index`` presumably is the document's position in
    # ``docs`` -- confirm against the Cohere rerank API reference.
    in_index_order = sorted(reranked.results, key=lambda hit: hit.index)
    return [hit.relevance_score for hit in in_index_order]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
@ -351,23 +353,38 @@ async def process_embed_request(
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not embed_request.documents or not embed_request.query:
|
||||
if not rerank_request.documents or not rerank_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing documents or query for reranking"
|
||||
)
|
||||
if not all(embed_request.documents):
|
||||
if not all(rerank_request.documents):
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
sim_scores = calc_sim_scores(
|
||||
query=embed_request.query, docs=embed_request.documents
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
if rerank_request.provider_type is None:
|
||||
sim_scores = local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = cohere_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
|
@ -15,10 +15,7 @@ from danswer.utils.logger import setup_logger
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.encoders import warm_up_cross_encoder
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
@ -64,8 +61,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
warm_up_intent_model()
|
||||
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
|
||||
warm_up_cross_encoder()
|
||||
else:
|
||||
logger.info("This model server should only run document indexing.")
|
||||
|
||||
|
Reference in New Issue
Block a user