Cohere Rerank (#2109)

This commit is contained in:
Yuhong Sun
2024-08-11 14:22:42 -07:00
committed by GitHub
parent ce666f3320
commit 386b229ed3
14 changed files with 95 additions and 69 deletions

View File

@ -22,11 +22,10 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
from shared_configs.configs import INDEXING_ONLY
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@ -226,7 +225,7 @@ def get_embedding_model(
def get_local_reranking_model(
model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
model_name: str,
) -> CrossEncoder:
global _RERANK_MODEL
if _RERANK_MODEL is None:
@ -236,13 +235,6 @@ def get_local_reranking_model(
return _RERANK_MODEL
def warm_up_cross_encoder() -> None:
logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")
cross_encoder = get_local_reranking_model()
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
@simple_log_function_time()
def embed_text(
texts: list[str],
@ -311,11 +303,21 @@ def embed_text(
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
cross_encoder = get_local_reranking_model()
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereClient(api_key=api_key)
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
@router.post("/bi-encoder-embed")
async def process_embed_request(
embed_request: EmbedRequest,
@ -351,23 +353,38 @@ async def process_embed_request(
@router.post("/cross-encoder-scores")
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
if not rerank_request.documents or not rerank_request.query:
raise HTTPException(
status_code=400, detail="Missing documents or query for reranking"
)
if not all(embed_request.documents):
if not all(rerank_request.documents):
raise ValueError("Empty documents cannot be reranked.")
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
)
return RerankResponse(scores=sim_scores)
if rerank_request.provider_type is None:
sim_scores = local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(

View File

@ -15,10 +15,7 @@ from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoder
from model_server.management_endpoints import router as management_router
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@ -64,8 +61,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if not INDEXING_ONLY:
warm_up_intent_model()
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
warm_up_cross_encoder()
else:
logger.info("This model server should only run document indexing.")