Cohere Rerank (#2109)

This commit is contained in:
Yuhong Sun 2024-08-11 14:22:42 -07:00 committed by GitHub
parent ce666f3320
commit 386b229ed3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 95 additions and 69 deletions

View File

@ -114,10 +114,11 @@ from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DEFAULT_CROSS_ENCODER_API_KEY
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.configs import DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
from shared_configs.configs import DISABLE_RERANK_FOR_STREAMING
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import RerankerProvider
logger = setup_logger()
@ -288,27 +289,28 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
else:
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
if DEFAULT_CROSS_ENCODER_MODEL_NAME:
logger.info("Reranking is enabled.")
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
raise ValueError("No reranking model specified.")
update_search_settings(
SavedSearchSettings(
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
disable_rerank_for_streaming=not ENABLE_RERANKING_REAL_TIME_FLOW,
num_rerank=NUM_POSTPROCESSED_RESULTS,
multilingual_expansion=[
s.strip()
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
if s.strip()
]
if MULTILINGUAL_QUERY_EXPANSION
else [],
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
)
update_search_settings(
SavedSearchSettings(
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE),
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
num_rerank=NUM_POSTPROCESSED_RESULTS,
multilingual_expansion=[
s.strip()
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
if s.strip()
]
if MULTILINGUAL_QUERY_EXPANSION
else [],
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
)
)
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()

View File

@ -18,6 +18,7 @@ from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@ -217,6 +218,7 @@ class RerankingModel:
def __init__(
self,
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
@ -224,6 +226,7 @@ class RerankingModel:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
def predict(self, query: str, passages: list[str]) -> list[float]:
@ -231,6 +234,7 @@ class RerankingModel:
query=query,
documents=passages,
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
)

View File

@ -13,6 +13,7 @@ from danswer.indexing.models import BaseChunk
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.enums import RerankerProvider
MAX_METRICS_CONTENT = (
@ -21,10 +22,11 @@ MAX_METRICS_CONTENT = (
class RerankingDetails(BaseModel):
rerank_model_name: str
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
provider_type: RerankerProvider | None
api_key: str | None
# Set to 0 to disable reranking explicitly
num_rerank: int
@ -41,6 +43,7 @@ class SavedSearchSettings(RerankingDetails):
def to_reranking_detail(self) -> RerankingDetails:
    """Project these saved search settings down to just the reranking fields."""
    detail_kwargs = {
        "rerank_model_name": self.rerank_model_name,
        "provider_type": self.provider_type,
        "api_key": self.api_key,
        "num_rerank": self.num_rerank,
    }
    return RerankingDetails(**detail_kwargs)

View File

@ -90,14 +90,15 @@ def semantic_reranking(
"""
rerank_settings = query.rerank_settings
if not rerank_settings:
if not rerank_settings or not rerank_settings.rerank_model_name:
# Should never reach this part of the flow without reranking settings
raise RuntimeError("Reranking settings not found")
raise RuntimeError("Reranking flow should not be running")
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
cross_encoder = RerankingModel(
model_name=rerank_settings.rerank_model_name,
provider_type=rerank_settings.provider_type,
api_key=rerank_settings.api_key,
)
@ -258,7 +259,11 @@ def search_postprocessing(
rerank_task_id = None
sections_yielded = False
if search_query.rerank_settings:
if (
search_query.rerank_settings
and search_query.rerank_settings.rerank_model_name
and search_query.rerank_settings.num_rerank > 0
):
post_processing_tasks.append(
FunctionCall(
rerank_sections,

View File

@ -22,11 +22,10 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
from shared_configs.configs import INDEXING_ONLY
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
@ -226,7 +225,7 @@ def get_embedding_model(
def get_local_reranking_model(
model_name: str = DEFAULT_CROSS_ENCODER_MODEL_NAME,
model_name: str,
) -> CrossEncoder:
global _RERANK_MODEL
if _RERANK_MODEL is None:
@ -236,13 +235,6 @@ def get_local_reranking_model(
return _RERANK_MODEL
def warm_up_cross_encoder() -> None:
logger.info(f"Warming up Cross-Encoder: {DEFAULT_CROSS_ENCODER_MODEL_NAME}")
cross_encoder = get_local_reranking_model()
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
@simple_log_function_time()
def embed_text(
texts: list[str],
@ -311,11 +303,21 @@ def embed_text(
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[float]:
cross_encoder = get_local_reranking_model()
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
    """Score each passage against the query using a locally loaded cross-encoder.

    Returns one relevance score per document, in the same order as `docs`.
    """
    encoder = get_local_reranking_model(model_name)
    query_doc_pairs = [(query, passage) for passage in docs]
    return encoder.predict(query_doc_pairs).tolist()  # type: ignore
def cohere_rerank(
    query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
    """Rerank passages via the Cohere Rerank API.

    Returns one relevance score per document, aligned positionally with `docs`.
    """
    client = CohereClient(api_key=api_key)
    rerank_response = client.rerank(query=query, documents=docs, model=model_name)
    # The API returns results ordered by relevance; re-sort by the original
    # document index so scores line up with the input list.
    by_input_order = sorted(rerank_response.results, key=lambda r: r.index)
    return [r.relevance_score for r in by_input_order]
@router.post("/bi-encoder-embed")
async def process_embed_request(
embed_request: EmbedRequest,
@ -351,23 +353,38 @@ async def process_embed_request(
@router.post("/cross-encoder-scores")
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
if not embed_request.documents or not embed_request.query:
if not rerank_request.documents or not rerank_request.query:
raise HTTPException(
status_code=400, detail="Missing documents or query for reranking"
)
if not all(embed_request.documents):
if not all(rerank_request.documents):
raise ValueError("Empty documents cannot be reranked.")
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
)
return RerankResponse(scores=sim_scores)
if rerank_request.provider_type is None:
sim_scores = local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(

View File

@ -15,10 +15,7 @@ from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoder
from model_server.management_endpoints import router as management_router
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@ -64,8 +61,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if not INDEXING_ONLY:
warm_up_intent_model()
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
warm_up_cross_encoder()
else:
logger.info("This model server should only run document indexing.")

View File

@ -20,17 +20,20 @@ INTENT_MODEL_TAG = "v1.0.3"
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Used for loading defaults for automatic deployments and dev flows
DEFAULT_CROSS_ENCODER_MODEL_NAME = "mixedbread-ai/mxbai-rerank-xsmall-v1"
DEFAULT_CROSS_ENCODER_API_KEY = os.environ.get("DEFAULT_CROSS_ENCODER_API_KEY")
# For local, use: mixedbread-ai/mxbai-rerank-xsmall-v1
DEFAULT_CROSS_ENCODER_MODEL_NAME = (
os.environ.get("DEFAULT_CROSS_ENCODER_MODEL_NAME") or None
)
DEFAULT_CROSS_ENCODER_API_KEY = os.environ.get("DEFAULT_CROSS_ENCODER_API_KEY") or None
DEFAULT_CROSS_ENCODER_PROVIDER_TYPE = (
os.environ.get("DEFAULT_CROSS_ENCODER_PROVIDER_TYPE") or None
)
DISABLE_RERANK_FOR_STREAMING = (
os.environ.get("DISABLE_RERANK_FOR_STREAMING", "").lower() == "true"
)
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.

View File

@ -8,6 +8,10 @@ class EmbeddingProvider(str, Enum):
GOOGLE = "google"
class RerankerProvider(str, Enum):
    """Hosted reranking providers; a provider of None elsewhere means a local model."""

    COHERE = "cohere"
class EmbedTextType(str, Enum):
    """Marks whether a text being embedded is a search query or an indexed passage."""

    QUERY = "query"
    PASSAGE = "passage"

View File

@ -2,6 +2,7 @@ from pydantic import BaseModel
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
Embedding = list[float]
@ -27,6 +28,7 @@ class RerankRequest(BaseModel):
query: str
documents: list[str]
model_name: str
provider_type: RerankerProvider | None
api_key: str | None

View File

@ -74,8 +74,7 @@ services:
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!

View File

@ -70,8 +70,7 @@ services:
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
- DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!

View File

@ -26,11 +26,6 @@ NORMALIZE_EMBEDDINGS="True"
# If using a common language like Spanish, French, Chinese, etc. this can be kept turned on
DISABLE_LLM_DOC_RELEVANCE="True"
# The default reranking models are English first
# There are no great quality French/English reranking models currently so turning this off
ENABLE_RERANKING_ASYNC_FLOW="False"
ENABLE_RERANKING_REAL_TIME_FLOW="False"
# Enables fine-grained embeddings for better retrieval
# At the cost of indexing speed (~5x slower), query time is same speed
# Since reranking is turned off and multilingual retrieval is generally harder

View File

@ -420,8 +420,7 @@ configMap:
NORMALIZE_EMBEDDINGS: ""
ASYM_QUERY_PREFIX: ""
ASYM_PASSAGE_PREFIX: ""
ENABLE_RERANKING_REAL_TIME_FLOW: ""
ENABLE_RERANKING_ASYNC_FLOW: ""
DISABLE_RERANK_FOR_STREAMING: ""
MODEL_SERVER_PORT: ""
MIN_THREADS_ML_MODELS: ""
# Indexing Configs

View File

@ -45,8 +45,7 @@ data:
NORMALIZE_EMBEDDINGS: ""
ASYM_QUERY_PREFIX: ""
ASYM_PASSAGE_PREFIX: ""
ENABLE_RERANKING_REAL_TIME_FLOW: ""
ENABLE_RERANKING_ASYNC_FLOW: ""
DISABLE_RERANK_FOR_STREAMING: ""
MODEL_SERVER_HOST: "inference-model-server-service"
MODEL_SERVER_PORT: ""
INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"