mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-03 09:28:25 +02:00
Cohere Rerank (#2109)
This commit is contained in:
parent
ce666f3320
commit
386b229ed3
@ -114,10 +114,11 @@ from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_API_KEY
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
|
||||
from shared_configs.configs import DISABLE_RERANK_FOR_STREAMING
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -288,27 +289,28 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
else:
|
||||
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
|
||||
if DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
||||
logger.info("Reranking is enabled.")
|
||||
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
||||
raise ValueError("No reranking model specified.")
|
||||
|
||||
update_search_settings(
|
||||
SavedSearchSettings(
|
||||
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
|
||||
disable_rerank_for_streaming=not ENABLE_RERANKING_REAL_TIME_FLOW,
|
||||
num_rerank=NUM_POSTPROCESSED_RESULTS,
|
||||
multilingual_expansion=[
|
||||
s.strip()
|
||||
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
|
||||
if s.strip()
|
||||
]
|
||||
if MULTILINGUAL_QUERY_EXPANSION
|
||||
else [],
|
||||
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
|
||||
)
|
||||
update_search_settings(
|
||||
SavedSearchSettings(
|
||||
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE),
|
||||
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
|
||||
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
|
||||
num_rerank=NUM_POSTPROCESSED_RESULTS,
|
||||
multilingual_expansion=[
|
||||
s.strip()
|
||||
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
|
||||
if s.strip()
|
||||
]
|
||||
if MULTILINGUAL_QUERY_EXPANSION
|
||||
else [],
|
||||
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
@ -18,6 +18,7 @@ from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
@ -217,6 +218,7 @@ class RerankingModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
provider_type: RerankerProvider | None,
|
||||
api_key: str | None,
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
@ -224,6 +226,7 @@ class RerankingModel:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[float]:
|
||||
@ -231,6 +234,7 @@ class RerankingModel:
|
||||
query=query,
|
||||
documents=passages,
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
)
|
||||
|
||||
|
@ -13,6 +13,7 @@ from danswer.indexing.models import BaseChunk
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
@ -21,10 +22,11 @@ MAX_METRICS_CONTENT = (
|
||||
|
||||
|
||||
class RerankingDetails(BaseModel):
|
||||
rerank_model_name: str
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
provider_type: RerankerProvider | None
|
||||
api_key: str | None
|
||||
|
||||
# Set to 0 to disable reranking explicitly
|
||||
num_rerank: int
|
||||
|
||||
|
||||
@ -41,6 +43,7 @@ class SavedSearchSettings(RerankingDetails):
|
||||
def to_reranking_detail(self) -> RerankingDetails:
|
||||
return RerankingDetails(
|
||||
rerank_model_name=self.rerank_model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
num_rerank=self.num_rerank,
|
||||
)
|
||||
|
@ -90,14 +90,15 @@ def semantic_reranking(
|
||||
"""
|
||||
rerank_settings = query.rerank_settings
|
||||
|
||||
if not rerank_settings:
|
||||
if not rerank_settings or not rerank_settings.rerank_model_name:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking settings not found")
|
||||
raise RuntimeError("Reranking flow should not be running")
|
||||
|
||||
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
|
||||
|
||||
cross_encoder = RerankingModel(
|
||||
model_name=rerank_settings.rerank_model_name,
|
||||
provider_type=rerank_settings.provider_type,
|
||||
api_key=rerank_settings.api_key,
|
||||
)
|
||||
|
||||
@ -258,7 +259,11 @@ def search_postprocessing(
|
||||
|
||||
rerank_task_id = None
|
||||
sections_yielded = False
|
||||
if search_query.rerank_settings:
|
||||
if (
|
||||
search_query.rerank_settings
|
||||
and search_query.rerank_settings.rerank_model_name
|
||||
and search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
rerank_sections,
|
||||
|
@ -22,11 +22,10 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.utils import simple_log_function_time
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
@ -226,7 +225,7 @@ def get_embedding_model(
|
||||
|
||||
|
||||
def get_local_reranking_model(
|
||||
model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||
model_name: str,
|
||||
) -> CrossEncoder:
|
||||
global _RERANK_MODEL
|
||||
if _RERANK_MODEL is None:
|
||||
@ -236,13 +235,6 @@ def get_local_reranking_model(
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
def warm_up_cross_encoder() -> None:
|
||||
logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")
|
||||
|
||||
cross_encoder = get_local_reranking_model()
|
||||
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
@ -311,11 +303,21 @@ def embed_text(
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model()
|
||||
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model(model_name)
|
||||
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||
|
||||
|
||||
def cohere_rerank(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereClient(api_key=api_key)
|
||||
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
|
||||
results = response.results
|
||||
sorted_results = sorted(results, key=lambda item: item.index)
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
@ -351,23 +353,38 @@ async def process_embed_request(
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
if not embed_request.documents or not embed_request.query:
|
||||
if not rerank_request.documents or not rerank_request.query:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Missing documents or query for reranking"
|
||||
)
|
||||
if not all(embed_request.documents):
|
||||
if not all(rerank_request.documents):
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
sim_scores = calc_sim_scores(
|
||||
query=embed_request.query, docs=embed_request.documents
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
if rerank_request.provider_type is None:
|
||||
sim_scores = local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = cohere_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
raise HTTPException(
|
||||
|
@ -15,10 +15,7 @@ from danswer.utils.logger import setup_logger
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.encoders import warm_up_cross_encoder
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import MIN_THREADS_ML_MODELS
|
||||
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
|
||||
@ -64,8 +61,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
warm_up_intent_model()
|
||||
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
|
||||
warm_up_cross_encoder()
|
||||
else:
|
||||
logger.info("This model server should only run document indexing.")
|
||||
|
||||
|
@ -20,17 +20,20 @@ INTENT_MODEL_TAG = "v1.0.3"
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
|
||||
# Cross Encoder Settings
|
||||
ENABLE_RERANKING_ASYNC_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
|
||||
)
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW = (
|
||||
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Used for loading defaults for automatic deployments and dev flows
|
||||
DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
|
||||
DEFAULT_CROSS_ENCODER_API_KEY = os.environ.get("DEFAULT_CROSS_ENCODER_API_KEY")
|
||||
# For local, use: mixedbread-ai/mxbai-rerank-xsmall-v1
|
||||
DEFAULT_CROSS_ENCODER_MODEL_NAME = (
|
||||
os.environ.get("DEFAULT_CROSS_ENCODER_MODEL_NAME") or None
|
||||
)
|
||||
DEFAULT_CROSS_ENCODER_API_KEY = os.environ.get("DEFAULT_CROSS_ENCODER_API_KEY") or None
|
||||
DEFAULT_CROSS_ENCODER_PROVIDER_TYPE = (
|
||||
os.environ.get("DEFAULT_CROSS_ENCODER_PROVIDER_TYPE") or None
|
||||
)
|
||||
DISABLE_RERANK_FOR_STREAMING = (
|
||||
os.environ.get("DISABLE_RERANK_FOR_STREAMING", "").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
# This controls the minimum number of pytorch "threads" to allocate to the embedding
|
||||
# model. If torch finds more threads on its own, this value is not used.
|
||||
|
@ -8,6 +8,10 @@ class EmbeddingProvider(str, Enum):
|
||||
GOOGLE = "google"
|
||||
|
||||
|
||||
class RerankerProvider(str, Enum):
|
||||
COHERE = "cohere"
|
||||
|
||||
|
||||
class EmbedTextType(str, Enum):
|
||||
QUERY = "query"
|
||||
PASSAGE = "passage"
|
||||
|
@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
Embedding = list[float]
|
||||
|
||||
@ -27,6 +28,7 @@ class RerankRequest(BaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
model_name: str
|
||||
provider_type: RerankerProvider | None
|
||||
api_key: str | None
|
||||
|
||||
|
||||
|
@ -74,8 +74,7 @@ services:
|
||||
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
|
||||
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
|
||||
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
|
||||
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
|
@ -70,8 +70,7 @@ services:
|
||||
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
|
||||
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
|
||||
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
|
||||
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
|
@ -26,11 +26,6 @@ NORMALIZE_EMBEDDINGS="True"
|
||||
# If using a common language like Spanish, French, Chinese, etc. this can be kept turned on
|
||||
DISABLE_LLM_DOC_RELEVANCE="True"
|
||||
|
||||
# The default reranking models are English first
|
||||
# There are no great quality French/English reranking models currently so turning this off
|
||||
ENABLE_RERANKING_ASYNC_FLOW="False"
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW="False"
|
||||
|
||||
# Enables fine-grained embeddings for better retrieval
|
||||
# At the cost of indexing speed (~5x slower), query time is same speed
|
||||
# Since reranking is turned off and multilingual retrieval is generally harder
|
||||
|
@ -420,8 +420,7 @@ configMap:
|
||||
NORMALIZE_EMBEDDINGS: ""
|
||||
ASYM_QUERY_PREFIX: ""
|
||||
ASYM_PASSAGE_PREFIX: ""
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW: ""
|
||||
ENABLE_RERANKING_ASYNC_FLOW: ""
|
||||
DISABLE_RERANK_FOR_STREAMING: ""
|
||||
MODEL_SERVER_PORT: ""
|
||||
MIN_THREADS_ML_MODELS: ""
|
||||
# Indexing Configs
|
||||
|
@ -45,8 +45,7 @@ data:
|
||||
NORMALIZE_EMBEDDINGS: ""
|
||||
ASYM_QUERY_PREFIX: ""
|
||||
ASYM_PASSAGE_PREFIX: ""
|
||||
ENABLE_RERANKING_REAL_TIME_FLOW: ""
|
||||
ENABLE_RERANKING_ASYNC_FLOW: ""
|
||||
DISABLE_RERANK_FOR_STREAMING: ""
|
||||
MODEL_SERVER_HOST: "inference-model-server-service"
|
||||
MODEL_SERVER_PORT: ""
|
||||
INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"
|
||||
|
Loading…
x
Reference in New Issue
Block a user